mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 11:23:25 +01:00
Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)
It can be useful to always return the fallback key when attempting to claim keys. This adds an unstable endpoint for `/keys/claim` which always returns fallback keys in addition to one-time-keys. The fallback key(s) are not marked as "used" unless there are no corresponding OTKs. This is currently defined in MSC3983 (although likely to be split out to a separate MSC). The endpoint shape may change or be requested differently (i.e. a keyword parameter on the current endpoint), but the core logic should be reasonable.
This commit is contained in:
parent
b39b02c26e
commit
8e9739449d
9 changed files with 371 additions and 29 deletions
1
changelog.d/15462.misc
Normal file
1
changelog.d/15462.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.
|
|
@ -1005,7 +1005,7 @@ class FederationServer(FederationBase):
|
|||
|
||||
@trace
|
||||
async def on_claim_client_keys(
|
||||
self, origin: str, content: JsonDict
|
||||
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
|
||||
) -> Dict[str, Any]:
|
||||
query = []
|
||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||
|
@ -1013,7 +1013,9 @@ class FederationServer(FederationBase):
|
|||
query.append((user_id, device_id, algorithm))
|
||||
|
||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
|
||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||
query, always_include_fallback_keys=always_include_fallback_keys
|
||||
)
|
||||
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for result in results:
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.federation.transport.server._base import (
|
|||
from synapse.federation.transport.server.federation import (
|
||||
FEDERATION_SERVLET_CLASSES,
|
||||
FederationAccountStatusServlet,
|
||||
FederationUnstableClientKeysClaimServlet,
|
||||
)
|
||||
from synapse.http.server import HttpServer, JsonResource
|
||||
from synapse.http.servlet import (
|
||||
|
@ -298,6 +299,11 @@ def register_servlets(
|
|||
and not hs.config.experimental.msc3720_enabled
|
||||
):
|
||||
continue
|
||||
if (
|
||||
servletclass == FederationUnstableClientKeysClaimServlet
|
||||
and not hs.config.experimental.msc3983_appservice_otk_claims
|
||||
):
|
||||
continue
|
||||
|
||||
servletclass(
|
||||
hs=hs,
|
||||
|
|
|
@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
|
|||
async def on_POST(
|
||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
response = await self.handler.on_claim_client_keys(origin, content)
|
||||
response = await self.handler.on_claim_client_keys(
|
||||
origin, content, always_include_fallback_keys=False
|
||||
)
|
||||
return 200, response
|
||||
|
||||
|
||||
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
|
||||
"""
|
||||
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
|
||||
always includes fallback keys in the response.
|
||||
"""
|
||||
|
||||
PREFIX = FEDERATION_UNSTABLE_PREFIX
|
||||
PATH = "/user/keys/claim"
|
||||
CATEGORY = "Federation requests"
|
||||
|
||||
async def on_POST(
|
||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
response = await self.handler.on_claim_client_keys(
|
||||
origin, content, always_include_fallback_keys=True
|
||||
)
|
||||
return 200, response
|
||||
|
||||
|
||||
|
|
|
@ -842,9 +842,7 @@ class ApplicationServicesHandler:
|
|||
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query: Iterable[Tuple[str, str, str]]
|
||||
) -> Tuple[
|
||||
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
|
||||
]:
|
||||
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||
"""Claim one time keys from application services.
|
||||
|
||||
Users which are exclusively owned by an application service are sent a
|
||||
|
@ -856,7 +854,7 @@ class ApplicationServicesHandler:
|
|||
|
||||
Returns:
|
||||
A tuple of:
|
||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
|
||||
A copy of the input which has not been fulfilled (either because
|
||||
they are not appservice users or the appservice does not support
|
||||
|
@ -897,12 +895,11 @@ class ApplicationServicesHandler:
|
|||
)
|
||||
|
||||
# Patch together the results -- they are all independent (since they
|
||||
# require exclusive control over the users). They get returned as a list
|
||||
# and the caller combines them.
|
||||
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
|
||||
# require exclusive control over the users, which is the outermost key).
|
||||
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for success, result in results:
|
||||
if success:
|
||||
claimed_keys.append(result[0])
|
||||
claimed_keys.update(result[0])
|
||||
missing.extend(result[1])
|
||||
|
||||
return claimed_keys, missing
|
||||
|
|
|
@ -563,7 +563,9 @@ class E2eKeysHandler:
|
|||
return ret
|
||||
|
||||
async def claim_local_one_time_keys(
|
||||
self, local_query: List[Tuple[str, str, str]]
|
||||
self,
|
||||
local_query: List[Tuple[str, str, str]],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
|
||||
"""Claim one time keys for local users.
|
||||
|
||||
|
@ -573,6 +575,7 @@ class E2eKeysHandler:
|
|||
|
||||
Args:
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
always_include_fallback_keys: True to always include fallback keys.
|
||||
|
||||
Returns:
|
||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
|
@ -583,24 +586,73 @@ class E2eKeysHandler:
|
|||
# If the application services have not provided any keys via the C-S
|
||||
# API, query it directly for one-time keys.
|
||||
if self._query_appservices_for_otks:
|
||||
# TODO Should this query for fallback keys of uploaded OTKs if
|
||||
# always_include_fallback_keys is True? The MSC is ambiguous.
|
||||
(
|
||||
appservice_results,
|
||||
not_found,
|
||||
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
||||
else:
|
||||
appservice_results = []
|
||||
appservice_results = {}
|
||||
|
||||
# Calculate which user ID / device ID / algorithm tuples to get fallback
|
||||
# keys for. This can be either only missing results *or* all results
|
||||
# (which don't already have a fallback key).
|
||||
if always_include_fallback_keys:
|
||||
# Build the fallback query as any part of the original query where
|
||||
# the appservice didn't respond with a fallback key.
|
||||
fallback_query = []
|
||||
|
||||
# Iterate each item in the original query and search the results
|
||||
# from the appservice for that user ID / device ID. If it is found,
|
||||
# check if any of the keys match the requested algorithm & are a
|
||||
# fallback key.
|
||||
for user_id, device_id, algorithm in local_query:
|
||||
# Check if the appservice responded for this query.
|
||||
as_result = appservice_results.get(user_id, {}).get(device_id, {})
|
||||
found_otk = False
|
||||
for key_id, key_json in as_result.items():
|
||||
if key_id.startswith(f"{algorithm}:"):
|
||||
# A OTK or fallback key was found for this query.
|
||||
found_otk = True
|
||||
# A fallback key was found for this query, no need to
|
||||
# query further.
|
||||
if key_json.get("fallback", False):
|
||||
break
|
||||
|
||||
else:
|
||||
# No fallback key was found from appservices, query for it.
|
||||
# Only mark the fallback key as used if no OTK was found
|
||||
# (from either the database or appservices).
|
||||
mark_as_used = not found_otk and not any(
|
||||
key_id.startswith(f"{algorithm}:")
|
||||
for key_id in otk_results.get(user_id, {})
|
||||
.get(device_id, {})
|
||||
.keys()
|
||||
)
|
||||
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
|
||||
|
||||
else:
|
||||
# All fallback keys get marked as used.
|
||||
fallback_query = [
|
||||
(user_id, device_id, algorithm, True)
|
||||
for user_id, device_id, algorithm in not_found
|
||||
]
|
||||
|
||||
# For each user that does not have a one-time keys available, see if
|
||||
# there is a fallback key.
|
||||
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
|
||||
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
|
||||
|
||||
# Return the results in order, each item from the input query should
|
||||
# only appear once in the combined list.
|
||||
return (otk_results, *appservice_results, fallback_results)
|
||||
return (otk_results, appservice_results, fallback_results)
|
||||
|
||||
@trace
|
||||
async def claim_one_time_keys(
|
||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
|
||||
self,
|
||||
query: Dict[str, Dict[str, Dict[str, str]]],
|
||||
timeout: Optional[int],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> JsonDict:
|
||||
local_query: List[Tuple[str, str, str]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
@ -617,7 +669,9 @@ class E2eKeysHandler:
|
|||
set_tag("local_key_query", str(local_query))
|
||||
set_tag("remote_key_query", str(remote_queries))
|
||||
|
||||
results = await self.claim_local_one_time_keys(local_query)
|
||||
results = await self.claim_local_one_time_keys(
|
||||
local_query, always_include_fallback_keys
|
||||
)
|
||||
|
||||
# A map of user ID -> device ID -> key ID -> key.
|
||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
|
@ -625,7 +679,9 @@ class E2eKeysHandler:
|
|||
for user_id, device_keys in result.items():
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, key in keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
|
||||
json_result.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
).update({key_id: key})
|
||||
|
||||
# Remote failures.
|
||||
failures: Dict[str, JsonDict] = {}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import InvalidAPICallError, SynapseError
|
||||
|
@ -288,7 +289,33 @@ class OneTimeKeyServlet(RestServlet):
|
|||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||
body, timeout, always_include_fallback_keys=False
|
||||
)
|
||||
return 200, result
|
||||
|
||||
|
||||
class UnstableOneTimeKeyServlet(RestServlet):
|
||||
"""
|
||||
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
|
||||
fallback keys in the response.
|
||||
"""
|
||||
|
||||
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
|
||||
CATEGORY = "Encryption requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||
body, timeout, always_include_fallback_keys=True
|
||||
)
|
||||
return 200, result
|
||||
|
||||
|
||||
|
@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
|||
KeyQueryServlet(hs).register(http_server)
|
||||
KeyChangesServlet(hs).register(http_server)
|
||||
OneTimeKeyServlet(hs).register(http_server)
|
||||
if hs.config.experimental.msc3983_appservice_otk_claims:
|
||||
UnstableOneTimeKeyServlet(hs).register(http_server)
|
||||
if hs.config.worker.worker_app is None:
|
||||
SigningKeyUploadServlet(hs).register(http_server)
|
||||
SignaturesUploadServlet(hs).register(http_server)
|
||||
|
|
|
@ -1149,18 +1149,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
return results, missing
|
||||
|
||||
async def claim_e2e_fallback_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str]]
|
||||
self, query_list: Iterable[Tuple[str, str, str, bool]]
|
||||
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Take a list of fallback keys out of the database.
|
||||
|
||||
Args:
|
||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
query_list: An iterable of tuples of
|
||||
(user ID, device ID, algorithm, whether the key should be marked as used).
|
||||
|
||||
Returns:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||
"""
|
||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
for user_id, device_id, algorithm in query_list:
|
||||
for user_id, device_id, algorithm, mark_as_used in query_list:
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
|
@ -1180,7 +1181,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
used = row["used"]
|
||||
|
||||
# Mark fallback key as used if not already.
|
||||
if not used:
|
||||
if not used and mark_as_used:
|
||||
await self.db_pool.simple_update_one(
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
|
|
|
@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
res2 = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# key
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# claiming an OTK again should return the same fallback key
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||
)
|
||||
|
||||
def test_fallback_key_always_returned(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||
otk = {"alg1:k2": "key2"}
|
||||
|
||||
# we shouldn't have any unused fallback keys yet
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(res, [])
|
||||
|
||||
# Upload a OTK & fallback key.
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{"one_time_keys": otk, "fallback_keys": fallback_key},
|
||||
)
|
||||
)
|
||||
|
||||
# we should now have an unused alg1 key
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Claiming an OTK and requesting to always return the fallback key should
|
||||
# return both.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
|
||||
},
|
||||
)
|
||||
|
||||
# This should not mark the key as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Claiming an OTK again should return only the fallback key.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||
)
|
||||
|
||||
# And mark it as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||
)
|
||||
self.assertEqual(fallback_res, [])
|
||||
|
||||
def test_replace_master_key(self) -> None:
|
||||
"""uploading a new signing key should make the old signing key unavailable"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
|
@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
}
|
||||
},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
|
@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
|
||||
def test_query_appservice_with_fallback(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id_1 = "xyz"
|
||||
fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
|
||||
otk = {"alg1:k2": {"desc": "key2"}}
|
||||
as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
|
||||
as_otk = {"alg1:k4": {"desc": "key4"}}
|
||||
|
||||
# Inject an appservice interested in this user.
|
||||
appservice = ApplicationService(
|
||||
token="i_am_an_app_service",
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
|
||||
# Note: this user does not have to match the regex above
|
||||
sender="@as_main:test",
|
||||
)
|
||||
self.hs.get_datastores().main.services_cache = [appservice]
|
||||
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
|
||||
[appservice]
|
||||
)
|
||||
|
||||
# Setup a response.
|
||||
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
|
||||
)
|
||||
|
||||
# Claim OTKs, which will ask the appservice and do nothing else.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_user: {device_id_1: {**as_otk, **as_fallback_key}}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Now upload a fallback key.
|
||||
res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(res, [])
|
||||
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id_1,
|
||||
{"fallback_keys": fallback_key},
|
||||
)
|
||||
)
|
||||
|
||||
# we should now have an unused alg1 key
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# The appservice will return only the OTK.
|
||||
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||
({local_user: {device_id_1: as_otk}}, [])
|
||||
)
|
||||
|
||||
# Claim OTKs, which should return the OTK from the appservice and the
|
||||
# uploaded fallback key.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {
|
||||
local_user: {device_id_1: {**as_otk, **fallback_key}}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# But the fallback key should not be marked as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Now upload a OTK.
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id_1,
|
||||
{"one_time_keys": otk},
|
||||
)
|
||||
)
|
||||
|
||||
# Claim OTKs, which will return information only from the database.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
|
||||
},
|
||||
)
|
||||
|
||||
# But the fallback key should not be marked as used.
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||
)
|
||||
self.assertEqual(fallback_res, ["alg1"])
|
||||
|
||||
# Finally, return only the fallback key from the appservice.
|
||||
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||
({local_user: {device_id_1: as_fallback_key}}, [])
|
||||
)
|
||||
|
||||
# Claim OTKs, which will return only the fallback key from the database.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
claim_res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id_1: as_fallback_key}},
|
||||
},
|
||||
)
|
||||
|
||||
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
|
||||
def test_query_local_devices_appservice(self) -> None:
|
||||
"""Test that querying of appservices for keys overrides responses from the database."""
|
||||
|
|
Loading…
Reference in a new issue