forked from MirrorHub/synapse
Add requesting user id parameter to key claim methods in TransportLayerClient
(#15663)
This commit is contained in:
parent
ca5c4be921
commit
8839b6c2f8
6 changed files with 39 additions and 11 deletions
1
changelog.d/15663.misc
Normal file
1
changelog.d/15663.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add requesting user id parameter to key claim methods in `TransportLayerClient`.
|
|
@ -236,6 +236,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
async def claim_client_keys(
|
async def claim_client_keys(
|
||||||
self,
|
self,
|
||||||
|
user: UserID,
|
||||||
destination: str,
|
destination: str,
|
||||||
query: Dict[str, Dict[str, Dict[str, int]]],
|
query: Dict[str, Dict[str, Dict[str, int]]],
|
||||||
timeout: Optional[int],
|
timeout: Optional[int],
|
||||||
|
@ -243,6 +244,7 @@ class FederationClient(FederationBase):
|
||||||
"""Claims one-time keys for a device hosted on a remote server.
|
"""Claims one-time keys for a device hosted on a remote server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
user: The user id of the requesting user
|
||||||
destination: Domain name of the remote homeserver
|
destination: Domain name of the remote homeserver
|
||||||
content: The query content.
|
content: The query content.
|
||||||
|
|
||||||
|
@ -279,7 +281,7 @@ class FederationClient(FederationBase):
|
||||||
if use_unstable:
|
if use_unstable:
|
||||||
try:
|
try:
|
||||||
return await self.transport_layer.claim_client_keys_unstable(
|
return await self.transport_layer.claim_client_keys_unstable(
|
||||||
destination, unstable_content, timeout
|
user, destination, unstable_content, timeout
|
||||||
)
|
)
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
# If an error is received that is due to an unrecognised endpoint,
|
# If an error is received that is due to an unrecognised endpoint,
|
||||||
|
@ -295,7 +297,7 @@ class FederationClient(FederationBase):
|
||||||
logger.debug("Skipping unstable claim client keys API")
|
logger.debug("Skipping unstable claim client keys API")
|
||||||
|
|
||||||
return await self.transport_layer.claim_client_keys(
|
return await self.transport_layer.claim_client_keys(
|
||||||
destination, content, timeout
|
user, destination, content, timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
|
|
|
@ -45,7 +45,7 @@ from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.federation.units import Transaction
|
from synapse.federation.units import Transaction
|
||||||
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
|
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
|
||||||
from synapse.http.types import QueryParams
|
from synapse.http.types import QueryParams
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.util import ExceptionBundle
|
from synapse.util import ExceptionBundle
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -630,7 +630,11 @@ class TransportLayerClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def claim_client_keys(
|
async def claim_client_keys(
|
||||||
self, destination: str, query_content: JsonDict, timeout: Optional[int]
|
self,
|
||||||
|
user: UserID,
|
||||||
|
destination: str,
|
||||||
|
query_content: JsonDict,
|
||||||
|
timeout: Optional[int],
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||||
|
|
||||||
|
@ -655,6 +659,7 @@ class TransportLayerClient:
|
||||||
}
|
}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
user: the user_id of the requesting user
|
||||||
destination: The server to query.
|
destination: The server to query.
|
||||||
query_content: The user ids to query.
|
query_content: The user ids to query.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -671,7 +676,11 @@ class TransportLayerClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def claim_client_keys_unstable(
|
async def claim_client_keys_unstable(
|
||||||
self, destination: str, query_content: JsonDict, timeout: Optional[int]
|
self,
|
||||||
|
user: UserID,
|
||||||
|
destination: str,
|
||||||
|
query_content: JsonDict,
|
||||||
|
timeout: Optional[int],
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||||
|
|
||||||
|
@ -696,6 +705,7 @@ class TransportLayerClient:
|
||||||
}
|
}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
user: the user_id of the requesting user
|
||||||
destination: The server to query.
|
destination: The server to query.
|
||||||
query_content: The user ids to query.
|
query_content: The user ids to query.
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -661,6 +661,7 @@ class E2eKeysHandler:
|
||||||
async def claim_one_time_keys(
|
async def claim_one_time_keys(
|
||||||
self,
|
self,
|
||||||
query: Dict[str, Dict[str, Dict[str, int]]],
|
query: Dict[str, Dict[str, Dict[str, int]]],
|
||||||
|
user: UserID,
|
||||||
timeout: Optional[int],
|
timeout: Optional[int],
|
||||||
always_include_fallback_keys: bool,
|
always_include_fallback_keys: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
|
@ -703,7 +704,7 @@ class E2eKeysHandler:
|
||||||
device_keys = remote_queries[destination]
|
device_keys = remote_queries[destination]
|
||||||
try:
|
try:
|
||||||
remote_result = await self.federation.claim_client_keys(
|
remote_result = await self.federation.claim_client_keys(
|
||||||
destination, device_keys, timeout=timeout
|
user, destination, device_keys, timeout=timeout
|
||||||
)
|
)
|
||||||
for user_id, keys in remote_result["one_time_keys"].items():
|
for user_id, keys in remote_result["one_time_keys"].items():
|
||||||
if user_id in device_keys:
|
if user_id in device_keys:
|
||||||
|
|
|
@ -287,7 +287,7 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
@ -298,7 +298,7 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
query.setdefault(user_id, {})[device_id] = {algorithm: 1}
|
query.setdefault(user_id, {})[device_id] = {algorithm: 1}
|
||||||
|
|
||||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||||
query, timeout, always_include_fallback_keys=False
|
query, requester.user, timeout, always_include_fallback_keys=False
|
||||||
)
|
)
|
||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
|
@ -335,7 +335,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
|
||||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
@ -346,7 +346,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
|
||||||
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
|
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
|
||||||
|
|
||||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||||
query, timeout, always_include_fallback_keys=True
|
query, requester.user, timeout, always_include_fallback_keys=True
|
||||||
)
|
)
|
||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers.device import DeviceHandler
|
from synapse.handlers.device import DeviceHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.appservice import _make_exclusive_regex
|
from synapse.storage.databases.main.appservice import _make_exclusive_regex
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -45,6 +45,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = hs.get_e2e_keys_handler()
|
self.handler = hs.get_e2e_keys_handler()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")
|
||||||
|
|
||||||
def test_query_local_devices_no_devices(self) -> None:
|
def test_query_local_devices_no_devices(self) -> None:
|
||||||
"""If the user has no devices, we expect an empty list."""
|
"""If the user has no devices, we expect an empty list."""
|
||||||
|
@ -161,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
res2 = self.get_success(
|
res2 = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
|
@ -206,6 +208,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
|
@ -225,6 +228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
|
@ -274,6 +278,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
|
@ -286,6 +291,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
|
@ -307,6 +313,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
|
@ -348,6 +355,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=True,
|
always_include_fallback_keys=True,
|
||||||
)
|
)
|
||||||
|
@ -370,6 +378,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=True,
|
always_include_fallback_keys=True,
|
||||||
)
|
)
|
||||||
|
@ -1080,6 +1089,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
|
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=False,
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
|
@ -1125,6 +1135,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id_1: {"alg1": 1}}},
|
{local_user: {device_id_1: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=True,
|
always_include_fallback_keys=True,
|
||||||
)
|
)
|
||||||
|
@ -1169,6 +1180,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id_1: {"alg1": 1}}},
|
{local_user: {device_id_1: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=True,
|
always_include_fallback_keys=True,
|
||||||
)
|
)
|
||||||
|
@ -1202,6 +1214,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id_1: {"alg1": 1}}},
|
{local_user: {device_id_1: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=True,
|
always_include_fallback_keys=True,
|
||||||
)
|
)
|
||||||
|
@ -1229,6 +1242,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id_1: {"alg1": 1}}},
|
{local_user: {device_id_1: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
always_include_fallback_keys=True,
|
always_include_fallback_keys=True,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue