Add requesting user id parameter to key claim methods in TransportLayerClient (#15663)

This commit is contained in:
Shay 2023-05-24 13:23:26 -07:00 committed by GitHub
parent ca5c4be921
commit 8839b6c2f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 39 additions and 11 deletions

1
changelog.d/15663.misc Normal file
View file

@ -0,0 +1 @@
Add requesting user id parameter to key claim methods in `TransportLayerClient`.

View file

@ -236,6 +236,7 @@ class FederationClient(FederationBase):
async def claim_client_keys( async def claim_client_keys(
self, self,
user: UserID,
destination: str, destination: str,
query: Dict[str, Dict[str, Dict[str, int]]], query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int], timeout: Optional[int],
@ -243,6 +244,7 @@ class FederationClient(FederationBase):
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
Args: Args:
user: The user id of the requesting user
destination: Domain name of the remote homeserver destination: Domain name of the remote homeserver
content: The query content. content: The query content.
@ -279,7 +281,7 @@ class FederationClient(FederationBase):
if use_unstable: if use_unstable:
try: try:
return await self.transport_layer.claim_client_keys_unstable( return await self.transport_layer.claim_client_keys_unstable(
destination, unstable_content, timeout user, destination, unstable_content, timeout
) )
except HttpResponseException as e: except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint, # If an error is received that is due to an unrecognised endpoint,
@ -295,7 +297,7 @@ class FederationClient(FederationBase):
logger.debug("Skipping unstable claim client keys API") logger.debug("Skipping unstable claim client keys API")
return await self.transport_layer.claim_client_keys( return await self.transport_layer.claim_client_keys(
destination, content, timeout user, destination, content, timeout
) )
@trace @trace

View file

@ -45,7 +45,7 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.types import JsonDict from synapse.types import JsonDict, UserID
from synapse.util import ExceptionBundle from synapse.util import ExceptionBundle
if TYPE_CHECKING: if TYPE_CHECKING:
@ -630,7 +630,11 @@ class TransportLayerClient:
) )
async def claim_client_keys( async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: Optional[int] self,
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
) -> JsonDict: ) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server. """Claim one-time keys for a list of devices hosted on a remote server.
@ -655,6 +659,7 @@ class TransportLayerClient:
} }
Args: Args:
user: the user_id of the requesting user
destination: The server to query. destination: The server to query.
query_content: The user ids to query. query_content: The user ids to query.
Returns: Returns:
@ -671,7 +676,11 @@ class TransportLayerClient:
) )
async def claim_client_keys_unstable( async def claim_client_keys_unstable(
self, destination: str, query_content: JsonDict, timeout: Optional[int] self,
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
) -> JsonDict: ) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server. """Claim one-time keys for a list of devices hosted on a remote server.
@ -696,6 +705,7 @@ class TransportLayerClient:
} }
Args: Args:
user: the user_id of the requesting user
destination: The server to query. destination: The server to query.
query_content: The user ids to query. query_content: The user ids to query.
Returns: Returns:

View file

@ -661,6 +661,7 @@ class E2eKeysHandler:
async def claim_one_time_keys( async def claim_one_time_keys(
self, self,
query: Dict[str, Dict[str, Dict[str, int]]], query: Dict[str, Dict[str, Dict[str, int]]],
user: UserID,
timeout: Optional[int], timeout: Optional[int],
always_include_fallback_keys: bool, always_include_fallback_keys: bool,
) -> JsonDict: ) -> JsonDict:
@ -703,7 +704,7 @@ class E2eKeysHandler:
device_keys = remote_queries[destination] device_keys = remote_queries[destination]
try: try:
remote_result = await self.federation.claim_client_keys( remote_result = await self.federation.claim_client_keys(
destination, device_keys, timeout=timeout user, destination, device_keys, timeout=timeout
) )
for user_id, keys in remote_result["one_time_keys"].items(): for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys: if user_id in device_keys:

View file

@ -287,7 +287,7 @@ class OneTimeKeyServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -298,7 +298,7 @@ class OneTimeKeyServlet(RestServlet):
query.setdefault(user_id, {})[device_id] = {algorithm: 1} query.setdefault(user_id, {})[device_id] = {algorithm: 1}
result = await self.e2e_keys_handler.claim_one_time_keys( result = await self.e2e_keys_handler.claim_one_time_keys(
query, timeout, always_include_fallback_keys=False query, requester.user, timeout, always_include_fallback_keys=False
) )
return 200, result return 200, result
@ -335,7 +335,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -346,7 +346,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
query.setdefault(user_id, {})[device_id] = Counter(algorithms) query.setdefault(user_id, {})[device_id] = Counter(algorithms)
result = await self.e2e_keys_handler.claim_one_time_keys( result = await self.e2e_keys_handler.claim_one_time_keys(
query, timeout, always_include_fallback_keys=True query, requester.user, timeout, always_include_fallback_keys=True
) )
return 200, result return 200, result

View file

@ -27,7 +27,7 @@ from synapse.appservice import ApplicationService
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.types import JsonDict from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -45,6 +45,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler() self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")
def test_query_local_devices_no_devices(self) -> None: def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list.""" """If the user has no devices, we expect an empty list."""
@ -161,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res2 = self.get_success( res2 = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -206,6 +208,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -225,6 +228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -274,6 +278,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -286,6 +291,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -307,6 +313,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -348,6 +355,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -370,6 +378,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1080,6 +1089,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}}, {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=False, always_include_fallback_keys=False,
) )
@ -1125,6 +1135,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}}, {local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1169,6 +1180,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}}, {local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1202,6 +1214,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}}, {local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )
@ -1229,6 +1242,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success( claim_res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}}, {local_user: {device_id_1: {"alg1": 1}}},
self.requester,
timeout=None, timeout=None,
always_include_fallback_keys=True, always_include_fallback_keys=True,
) )