Pass the Requester down to the HttpTransactionCache. (#15200)

This commit is contained in:
Quentin Gliech 2023-03-07 17:05:22 +01:00 committed by GitHub
parent 820f02b70b
commit 47bc84dd53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 217 additions and 131 deletions

1
changelog.d/15200.misc Normal file
View file

@ -0,0 +1 @@
Make the `HttpTransactionCache` use the `Requester` in addition of the just the `Request` to build the transaction key.

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
@ -23,10 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin from synapse.logging.opentracing import set_tag
from synapse.rest.admin._base import admin_patterns from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, Requester, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -70,10 +70,13 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__, self.__class__.__name__,
) )
async def on_POST( async def _do(
self, request: SynapseRequest, txn_id: Optional[str] = None self,
request: SynapseRequest,
requester: Requester,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_user_is_admin(self.auth, requester)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content")) assert_params_in_dict(body, ("user_id", "content"))
event_type = body.get("type", EventTypes.Message) event_type = body.get("type", EventTypes.Message)
@ -106,9 +109,18 @@ class SendServerNoticeServlet(RestServlet):
return HTTPStatus.OK, {"event_id": event.event_id} return HTTPStatus.OK, {"event_id": event.event_id}
def on_PUT( async def on_POST(
self,
request: SynapseRequest,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester, None)
async def on_PUT(
self, request: SynapseRequest, txn_id: str self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request( requester = await self.auth.get_user_by_req(request)
request, self.on_POST, request, txn_id set_tag("txn_id", txn_id)
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, txn_id
) )

View file

@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.cancellation import cancellable from synapse.util.cancellation import cancellable
@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
def on_PUT( async def on_PUT(
self, request: SynapseRequest, txn_id: str self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request) return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester)
async def _do(
self, request: SynapseRequest, requester: Requester
) -> Tuple[int, JsonDict]:
room_id, _, _ = await self._room_creation_handler.create_room( room_id, _, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request) requester, self.get_room_config(request)
) )
@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet): class RoomStateEventRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def register(self, http_server: HttpServer) -> None: def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True) register_txn_path(self, PATTERNS, http_server)
async def on_POST( async def _do(
self, self,
request: SynapseRequest, request: SynapseRequest,
requester: Requester,
room_id: str, room_id: str,
event_type: str, event_type: str,
txn_id: Optional[str] = None, txn_id: Optional[str],
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event_dict: JsonDict = { event_dict: JsonDict = {
@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id) set_tag("event_id", event_id)
return 200, {"event_id": event_id} return 200, {"event_id": event_id}
def on_GET( async def on_POST(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str self,
) -> Tuple[int, str]: request: SynapseRequest,
return 200, "Not implemented" room_id: str,
event_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, event_type, None)
def on_PUT( async def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return await self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id request,
requester,
self._do,
request,
requester,
room_id,
event_type,
txn_id,
) )
@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)" PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST( async def _do(
self, self,
request: SynapseRequest, request: SynapseRequest,
requester: Requester,
room_identifier: str, room_identifier: str,
txn_id: Optional[str] = None, txn_id: Optional[str],
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request, allow_empty_body=True) content = parse_json_object_from_request(request, allow_empty_body=True)
# twisted.web.server.Request.args is incorrectly defined as Optional[Any] # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id} return 200, {"room_id": room_id}
def on_PUT( async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_identifier, None)
async def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return await self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id request, requester, self._do, request, requester, room_identifier, txn_id
) )
# TODO: Needs unit testing # TODO: Needs unit testing
class PublicRoomListRestServlet(TransactionRestServlet): class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True) PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST( async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id) await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {} return 200, {}
def on_PUT( async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
return await self._do(requester, room_id)
async def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str self, request: SynapseRequest, room_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return await self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id request, requester, self._do, requester, room_id
) )
@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
) )
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST( async def _do(
self, self,
request: SynapseRequest, request: SynapseRequest,
requester: Requester,
room_id: str, room_id: str,
membership_action: str, membership_action: str,
txn_id: Optional[str] = None, txn_id: Optional[str],
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in { if requester.is_guest and membership_action not in {
Membership.JOIN, Membership.JOIN,
Membership.LEAVE, Membership.LEAVE,
@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value return 200, return_value
def on_PUT( async def on_POST(
self,
request: SynapseRequest,
room_id: str,
membership_action: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, membership_action, None)
async def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return await self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id request,
requester,
self._do,
request,
requester,
room_id,
membership_action,
txn_id,
) )
@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST( async def _do(
self, self,
request: SynapseRequest, request: SynapseRequest,
requester: Requester,
room_id: str, room_id: str,
event_id: str, event_id: str,
txn_id: Optional[str] = None, txn_id: Optional[str],
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
try: try:
@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id) set_tag("event_id", event_id)
return 200, {"event_id": event_id} return 200, {"event_id": event_id}
def on_PUT( async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester, room_id, event_id, None)
async def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return await self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id request, requester, self._do, request, requester, room_id, event_id, txn_id
) )
@ -1224,7 +1280,6 @@ def register_txn_path(
servlet: RestServlet, servlet: RestServlet,
regex_string: str, regex_string: str,
http_server: HttpServer, http_server: HttpServer,
with_get: bool = False,
) -> None: ) -> None:
"""Registers a transaction-based path. """Registers a transaction-based path.
@ -1236,7 +1291,6 @@ def register_txn_path(
regex_string: The regex string to register. Must NOT have a regex_string: The regex string to register. Must NOT have a
trailing $ as this string will be appended to. trailing $ as this string will be appended to.
http_server: The http_server to register paths with. http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
""" """
on_POST = getattr(servlet, "on_POST", None) on_POST = getattr(servlet, "on_POST", None)
on_PUT = getattr(servlet, "on_PUT", None) on_PUT = getattr(servlet, "on_PUT", None)
@ -1254,18 +1308,6 @@ def register_txn_path(
on_PUT, on_PUT,
servlet.__class__.__name__, servlet.__class__.__name__,
) )
on_GET = getattr(servlet, "on_GET", None)
if with_get:
if on_GET is None:
raise RuntimeError(
"register_txn_path called with with_get = True, but no on_GET method exists"
)
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
on_GET,
servlet.__class__.__name__,
)
class TimestampLookupRestServlet(RestServlet): class TimestampLookupRestServlet(RestServlet):

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.http import servlet from synapse.http import servlet
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
@ -21,7 +21,7 @@ from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_r
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict from synapse.types import JsonDict, Requester
from ._base import client_patterns from ._base import client_patterns
@ -43,19 +43,26 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
def on_PUT( async def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
async def _put(
self, request: SynapseRequest, message_type: str, txn_id: str self, request: SynapseRequest, message_type: str, txn_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return await self.txns.fetch_or_execute_request(
request,
requester,
self._put,
request,
requester,
message_type,
)
async def _put(
self,
request: SynapseRequest,
requester: Requester,
message_type: str,
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("messages",)) assert_params_in_dict(content, ("messages",))

View file

@ -15,16 +15,16 @@
"""This module contains logic for storing HTTP PUT transactions. This is used """This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.server import Request from twisted.web.iweb import IRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import JsonDict from synapse.types import JsonDict, Requester
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
if TYPE_CHECKING: if TYPE_CHECKING:
@ -41,53 +41,47 @@ P = ParamSpec("P")
class HttpTransactionCache: class HttpTransactionCache:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
self.transactions: Dict[ self.transactions: Dict[
str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] Hashable, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
] = {} ] = {}
# Try to clean entries every 30 mins. This means entries will exist # Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins. # for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request: Request) -> str: def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
"""A helper function which returns a transaction key that can be used """A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests. with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user. path and attributes from the requester: the access_token_id for regular users,
the user ID for guest users, and the appservice ID for appservice users.
Args: Args:
request: The incoming request. Must contain an access_token. request: The incoming request.
requester: The requester doing the request.
Returns: Returns:
A transaction key A transaction key
""" """
assert request.path is not None assert request.path is not None
token = self.auth.get_access_token_from_request(request) path: str = request.path.decode("utf8")
return request.path.decode("utf8") + "/" + token if requester.is_guest:
assert requester.user is not None, "Guest requester must have a user ID set"
return (path, "guest", requester.user)
elif requester.app_service is not None:
return (path, "appservice", requester.app_service.id)
else:
assert (
requester.access_token_id is not None
), "Requester must have an access_token_id"
return (path, "user", requester.access_token_id)
def fetch_or_execute_request( def fetch_or_execute_request(
self, self,
request: Request, request: IRequest,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], requester: Requester,
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
See:
fetch_or_execute
"""
return self.fetch_or_execute(
self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(
self,
txn_key: str,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args, *args: P.args,
**kwargs: P.kwargs, **kwargs: P.kwargs,
@ -96,14 +90,15 @@ class HttpTransactionCache:
to produce a response for this transaction. to produce a response for this transaction.
Args: Args:
txn_key: A key to ensure idempotency should fetch_or_execute be request:
called again at a later point in time. requester:
fn: A function which returns a tuple of (response_code, response_dict). fn: A function which returns a tuple of (response_code, response_dict).
*args: Arguments to pass to fn. *args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn. **kwargs: Keyword arguments to pass to fn.
Returns: Returns:
Deferred which resolves to a tuple of (response_code, response_dict). Deferred which resolves to a tuple of (response_code, response_dict).
""" """
txn_key = self._get_transaction_key(request, requester)
if txn_key in self.transactions: if txn_key in self.transactions:
observable = self.transactions[txn_key][0] observable = self.transactions[txn_key][0]
else: else:

View file

@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.cache = HttpTransactionCache(self.hs) self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"}) self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
self.mock_key = "foo"
# Here we make sure that we're setting all the fields that HttpTransactionCache
# uses to build the transaction key.
self.mock_request = Mock()
self.mock_request.path = b"/foo/bar"
self.mock_requester = Mock()
self.mock_requester.app_service = None
self.mock_requester.is_guest = False
self.mock_requester.access_token_id = 1234
@defer.inlineCallbacks @defer.inlineCallbacks
def test_executes_given_function( def test_executes_given_function(
self, self,
) -> Generator["defer.Deferred[Any]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response)) cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute( res = yield self.cache.fetch_or_execute_request(
self.mock_key, cb, "some_arg", keyword="arg" self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
) )
cb.assert_called_once_with("some_arg", keyword="arg") cb.assert_called_once_with("some_arg", keyword="arg")
self.assertEqual(res, self.mock_http_response) self.assertEqual(res, self.mock_http_response)
@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]: ) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response)) cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute( res = yield self.cache.fetch_or_execute_request(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i self.mock_request,
self.mock_requester,
cb,
"some_arg",
keyword="arg",
changing_args=i,
) )
self.assertEqual(res, self.mock_http_response) self.assertEqual(res, self.mock_http_response)
# expect only a single call to do the work # expect only a single call to do the work
@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test() -> Generator["defer.Deferred[Any]", object, None]: def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1: with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
self.assertIs(current_context(), c1) self.assertIs(current_context(), c1)
self.assertEqual(res, (1, {})) self.assertEqual(res, (1, {}))
@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context: with LoggingContext("test") as test_context:
try: try:
yield self.cache.fetch_or_execute(self.mock_key, cb) yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
except Exception as e: except Exception as e:
self.assertEqual(e.args[0], "boo") self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context) self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
self.assertEqual(res, self.mock_http_response) self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context) self.assertIs(current_context(), test_context)
@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context: with LoggingContext("test") as test_context:
try: try:
yield self.cache.fetch_or_execute(self.mock_key, cb) yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
except Exception as e: except Exception as e:
self.assertEqual(e.args[0], "boo") self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context) self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
self.assertEqual(res, self.mock_http_response) self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context) self.assertIs(current_context(), test_context)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response)) cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
# should NOT have cleaned up yet # should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
# still using cache # still using cache
cb.assert_called_once_with("an arg") cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS) self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
# no longer using cache # no longer using cache
self.assertEqual(cb.call_count, 2) self.assertEqual(cb.call_count, 2)
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")]) self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])