mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 19:13:51 +01:00
Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)
It can be useful to always return the fallback key when attempting to claim keys. This adds an unstable endpoint for `/keys/claim` which always returns fallback keys in addition to one-time-keys. The fallback key(s) are not marked as "used" unless there are no corresponding OTKs. This is currently defined in MSC3983 (although likely to be split out to a separate MSC). The endpoint shape may change or be requested differently (i.e. a keyword parameter on the current endpoint), but the core logic should be reasonable.
This commit is contained in:
parent
b39b02c26e
commit
8e9739449d
9 changed files with 371 additions and 29 deletions
1
changelog.d/15462.misc
Normal file
1
changelog.d/15462.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.
|
|
@ -1005,7 +1005,7 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def on_claim_client_keys(
|
async def on_claim_client_keys(
|
||||||
self, origin: str, content: JsonDict
|
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
query = []
|
query = []
|
||||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||||
|
@ -1013,7 +1013,9 @@ class FederationServer(FederationBase):
|
||||||
query.append((user_id, device_id, algorithm))
|
query.append((user_id, device_id, algorithm))
|
||||||
|
|
||||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
|
results = await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||||
|
query, always_include_fallback_keys=always_include_fallback_keys
|
||||||
|
)
|
||||||
|
|
||||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
for result in results:
|
for result in results:
|
||||||
|
|
|
@ -25,6 +25,7 @@ from synapse.federation.transport.server._base import (
|
||||||
from synapse.federation.transport.server.federation import (
|
from synapse.federation.transport.server.federation import (
|
||||||
FEDERATION_SERVLET_CLASSES,
|
FEDERATION_SERVLET_CLASSES,
|
||||||
FederationAccountStatusServlet,
|
FederationAccountStatusServlet,
|
||||||
|
FederationUnstableClientKeysClaimServlet,
|
||||||
)
|
)
|
||||||
from synapse.http.server import HttpServer, JsonResource
|
from synapse.http.server import HttpServer, JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
|
@ -298,6 +299,11 @@ def register_servlets(
|
||||||
and not hs.config.experimental.msc3720_enabled
|
and not hs.config.experimental.msc3720_enabled
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
if (
|
||||||
|
servletclass == FederationUnstableClientKeysClaimServlet
|
||||||
|
and not hs.config.experimental.msc3983_appservice_otk_claims
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
servletclass(
|
servletclass(
|
||||||
hs=hs,
|
hs=hs,
|
||||||
|
|
|
@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
response = await self.handler.on_claim_client_keys(origin, content)
|
response = await self.handler.on_claim_client_keys(
|
||||||
|
origin, content, always_include_fallback_keys=False
|
||||||
|
)
|
||||||
|
return 200, response
|
||||||
|
|
||||||
|
|
||||||
|
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
|
||||||
|
"""
|
||||||
|
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
|
||||||
|
always includes fallback keys in the response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PREFIX = FEDERATION_UNSTABLE_PREFIX
|
||||||
|
PATH = "/user/keys/claim"
|
||||||
|
CATEGORY = "Federation requests"
|
||||||
|
|
||||||
|
async def on_POST(
|
||||||
|
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
response = await self.handler.on_claim_client_keys(
|
||||||
|
origin, content, always_include_fallback_keys=True
|
||||||
|
)
|
||||||
return 200, response
|
return 200, response
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -842,9 +842,7 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
async def claim_e2e_one_time_keys(
|
async def claim_e2e_one_time_keys(
|
||||||
self, query: Iterable[Tuple[str, str, str]]
|
self, query: Iterable[Tuple[str, str, str]]
|
||||||
) -> Tuple[
|
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||||
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
|
|
||||||
]:
|
|
||||||
"""Claim one time keys from application services.
|
"""Claim one time keys from application services.
|
||||||
|
|
||||||
Users which are exclusively owned by an application service are sent a
|
Users which are exclusively owned by an application service are sent a
|
||||||
|
@ -856,7 +854,7 @@ class ApplicationServicesHandler:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
A tuple of:
|
||||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||||
|
|
||||||
A copy of the input which has not been fulfilled (either because
|
A copy of the input which has not been fulfilled (either because
|
||||||
they are not appservice users or the appservice does not support
|
they are not appservice users or the appservice does not support
|
||||||
|
@ -897,12 +895,11 @@ class ApplicationServicesHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Patch together the results -- they are all independent (since they
|
# Patch together the results -- they are all independent (since they
|
||||||
# require exclusive control over the users). They get returned as a list
|
# require exclusive control over the users, which is the outermost key).
|
||||||
# and the caller combines them.
|
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
|
|
||||||
for success, result in results:
|
for success, result in results:
|
||||||
if success:
|
if success:
|
||||||
claimed_keys.append(result[0])
|
claimed_keys.update(result[0])
|
||||||
missing.extend(result[1])
|
missing.extend(result[1])
|
||||||
|
|
||||||
return claimed_keys, missing
|
return claimed_keys, missing
|
||||||
|
|
|
@ -563,7 +563,9 @@ class E2eKeysHandler:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def claim_local_one_time_keys(
|
async def claim_local_one_time_keys(
|
||||||
self, local_query: List[Tuple[str, str, str]]
|
self,
|
||||||
|
local_query: List[Tuple[str, str, str]],
|
||||||
|
always_include_fallback_keys: bool,
|
||||||
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
|
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
|
||||||
"""Claim one time keys for local users.
|
"""Claim one time keys for local users.
|
||||||
|
|
||||||
|
@ -573,6 +575,7 @@ class E2eKeysHandler:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||||
|
always_include_fallback_keys: True to always include fallback keys.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||||
|
@ -583,24 +586,73 @@ class E2eKeysHandler:
|
||||||
# If the application services have not provided any keys via the C-S
|
# If the application services have not provided any keys via the C-S
|
||||||
# API, query it directly for one-time keys.
|
# API, query it directly for one-time keys.
|
||||||
if self._query_appservices_for_otks:
|
if self._query_appservices_for_otks:
|
||||||
|
# TODO Should this query for fallback keys of uploaded OTKs if
|
||||||
|
# always_include_fallback_keys is True? The MSC is ambiguous.
|
||||||
(
|
(
|
||||||
appservice_results,
|
appservice_results,
|
||||||
not_found,
|
not_found,
|
||||||
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
||||||
else:
|
else:
|
||||||
appservice_results = []
|
appservice_results = {}
|
||||||
|
|
||||||
|
# Calculate which user ID / device ID / algorithm tuples to get fallback
|
||||||
|
# keys for. This can be either only missing results *or* all results
|
||||||
|
# (which don't already have a fallback key).
|
||||||
|
if always_include_fallback_keys:
|
||||||
|
# Build the fallback query as any part of the original query where
|
||||||
|
# the appservice didn't respond with a fallback key.
|
||||||
|
fallback_query = []
|
||||||
|
|
||||||
|
# Iterate each item in the original query and search the results
|
||||||
|
# from the appservice for that user ID / device ID. If it is found,
|
||||||
|
# check if any of the keys match the requested algorithm & are a
|
||||||
|
# fallback key.
|
||||||
|
for user_id, device_id, algorithm in local_query:
|
||||||
|
# Check if the appservice responded for this query.
|
||||||
|
as_result = appservice_results.get(user_id, {}).get(device_id, {})
|
||||||
|
found_otk = False
|
||||||
|
for key_id, key_json in as_result.items():
|
||||||
|
if key_id.startswith(f"{algorithm}:"):
|
||||||
|
# A OTK or fallback key was found for this query.
|
||||||
|
found_otk = True
|
||||||
|
# A fallback key was found for this query, no need to
|
||||||
|
# query further.
|
||||||
|
if key_json.get("fallback", False):
|
||||||
|
break
|
||||||
|
|
||||||
|
else:
|
||||||
|
# No fallback key was found from appservices, query for it.
|
||||||
|
# Only mark the fallback key as used if no OTK was found
|
||||||
|
# (from either the database or appservices).
|
||||||
|
mark_as_used = not found_otk and not any(
|
||||||
|
key_id.startswith(f"{algorithm}:")
|
||||||
|
for key_id in otk_results.get(user_id, {})
|
||||||
|
.get(device_id, {})
|
||||||
|
.keys()
|
||||||
|
)
|
||||||
|
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# All fallback keys get marked as used.
|
||||||
|
fallback_query = [
|
||||||
|
(user_id, device_id, algorithm, True)
|
||||||
|
for user_id, device_id, algorithm in not_found
|
||||||
|
]
|
||||||
|
|
||||||
# For each user that does not have a one-time keys available, see if
|
# For each user that does not have a one-time keys available, see if
|
||||||
# there is a fallback key.
|
# there is a fallback key.
|
||||||
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
|
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
|
||||||
|
|
||||||
# Return the results in order, each item from the input query should
|
# Return the results in order, each item from the input query should
|
||||||
# only appear once in the combined list.
|
# only appear once in the combined list.
|
||||||
return (otk_results, *appservice_results, fallback_results)
|
return (otk_results, appservice_results, fallback_results)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def claim_one_time_keys(
|
async def claim_one_time_keys(
|
||||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
|
self,
|
||||||
|
query: Dict[str, Dict[str, Dict[str, str]]],
|
||||||
|
timeout: Optional[int],
|
||||||
|
always_include_fallback_keys: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
local_query: List[Tuple[str, str, str]] = []
|
local_query: List[Tuple[str, str, str]] = []
|
||||||
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||||
|
@ -617,7 +669,9 @@ class E2eKeysHandler:
|
||||||
set_tag("local_key_query", str(local_query))
|
set_tag("local_key_query", str(local_query))
|
||||||
set_tag("remote_key_query", str(remote_queries))
|
set_tag("remote_key_query", str(remote_queries))
|
||||||
|
|
||||||
results = await self.claim_local_one_time_keys(local_query)
|
results = await self.claim_local_one_time_keys(
|
||||||
|
local_query, always_include_fallback_keys
|
||||||
|
)
|
||||||
|
|
||||||
# A map of user ID -> device ID -> key ID -> key.
|
# A map of user ID -> device ID -> key ID -> key.
|
||||||
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
|
@ -625,7 +679,9 @@ class E2eKeysHandler:
|
||||||
for user_id, device_keys in result.items():
|
for user_id, device_keys in result.items():
|
||||||
for device_id, keys in device_keys.items():
|
for device_id, keys in device_keys.items():
|
||||||
for key_id, key in keys.items():
|
for key_id, key in keys.items():
|
||||||
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
|
json_result.setdefault(user_id, {}).setdefault(
|
||||||
|
device_id, {}
|
||||||
|
).update({key_id: key})
|
||||||
|
|
||||||
# Remote failures.
|
# Remote failures.
|
||||||
failures: Dict[str, JsonDict] = {}
|
failures: Dict[str, JsonDict] = {}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import InvalidAPICallError, SynapseError
|
from synapse.api.errors import InvalidAPICallError, SynapseError
|
||||||
|
@ -288,7 +289,33 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
|
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||||
|
body, timeout, always_include_fallback_keys=False
|
||||||
|
)
|
||||||
|
return 200, result
|
||||||
|
|
||||||
|
|
||||||
|
class UnstableOneTimeKeyServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
|
||||||
|
fallback keys in the response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
|
||||||
|
CATEGORY = "Encryption requests"
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||||
|
|
||||||
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
|
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||||
|
body, timeout, always_include_fallback_keys=True
|
||||||
|
)
|
||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
|
|
||||||
|
@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
KeyQueryServlet(hs).register(http_server)
|
KeyQueryServlet(hs).register(http_server)
|
||||||
KeyChangesServlet(hs).register(http_server)
|
KeyChangesServlet(hs).register(http_server)
|
||||||
OneTimeKeyServlet(hs).register(http_server)
|
OneTimeKeyServlet(hs).register(http_server)
|
||||||
|
if hs.config.experimental.msc3983_appservice_otk_claims:
|
||||||
|
UnstableOneTimeKeyServlet(hs).register(http_server)
|
||||||
if hs.config.worker.worker_app is None:
|
if hs.config.worker.worker_app is None:
|
||||||
SigningKeyUploadServlet(hs).register(http_server)
|
SigningKeyUploadServlet(hs).register(http_server)
|
||||||
SignaturesUploadServlet(hs).register(http_server)
|
SignaturesUploadServlet(hs).register(http_server)
|
||||||
|
|
|
@ -1149,18 +1149,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
return results, missing
|
return results, missing
|
||||||
|
|
||||||
async def claim_e2e_fallback_keys(
|
async def claim_e2e_fallback_keys(
|
||||||
self, query_list: Iterable[Tuple[str, str, str]]
|
self, query_list: Iterable[Tuple[str, str, str, bool]]
|
||||||
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
|
||||||
"""Take a list of fallback keys out of the database.
|
"""Take a list of fallback keys out of the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
query_list: An iterable of tuples of
|
||||||
|
(user ID, device ID, algorithm, whether the key should be marked as used).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
A map of user ID -> a map device ID -> a map of key ID -> JSON.
|
||||||
"""
|
"""
|
||||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
for user_id, device_id, algorithm in query_list:
|
for user_id, device_id, algorithm, mark_as_used in query_list:
|
||||||
row = await self.db_pool.simple_select_one(
|
row = await self.db_pool.simple_select_one(
|
||||||
table="e2e_fallback_keys_json",
|
table="e2e_fallback_keys_json",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
|
@ -1180,7 +1181,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
used = row["used"]
|
used = row["used"]
|
||||||
|
|
||||||
# Mark fallback key as used if not already.
|
# Mark fallback key as used if not already.
|
||||||
if not used:
|
if not used and mark_as_used:
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="e2e_fallback_keys_json",
|
table="e2e_fallback_keys_json",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
|
|
|
@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
res2 = self.get_success(
|
res2 = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
# key
|
# key
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
# claiming an OTK again should return the same fallback key
|
# claiming an OTK again should return the same fallback key
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
claim_res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_fallback_key_always_returned(self) -> None:
|
||||||
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id = "xyz"
|
||||||
|
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||||
|
otk = {"alg1:k2": "key2"}
|
||||||
|
|
||||||
|
# we shouldn't have any unused fallback keys yet
|
||||||
|
res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
|
)
|
||||||
|
self.assertEqual(res, [])
|
||||||
|
|
||||||
|
# Upload a OTK & fallback key.
|
||||||
|
self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
local_user,
|
||||||
|
device_id,
|
||||||
|
{"one_time_keys": otk, "fallback_keys": fallback_key},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# we should now have an unused alg1 key
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, ["alg1"])
|
||||||
|
|
||||||
|
# Claiming an OTK and requesting to always return the fallback key should
|
||||||
|
# return both.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim_res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should not mark the key as used.
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, ["alg1"])
|
||||||
|
|
||||||
|
# Claiming an OTK again should return only the fallback key.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim_res,
|
||||||
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# And mark it as used.
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, [])
|
||||||
|
|
||||||
def test_replace_master_key(self) -> None:
|
def test_replace_master_key(self) -> None:
|
||||||
"""uploading a new signing key should make the old signing key unavailable"""
|
"""uploading a new signing key should make the old signing key unavailable"""
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
timeout=None,
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
|
||||||
|
def test_query_appservice_with_fallback(self) -> None:
|
||||||
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id_1 = "xyz"
|
||||||
|
fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
|
||||||
|
otk = {"alg1:k2": {"desc": "key2"}}
|
||||||
|
as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
|
||||||
|
as_otk = {"alg1:k4": {"desc": "key4"}}
|
||||||
|
|
||||||
|
# Inject an appservice interested in this user.
|
||||||
|
appservice = ApplicationService(
|
||||||
|
token="i_am_an_app_service",
|
||||||
|
id="1234",
|
||||||
|
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
|
||||||
|
# Note: this user does not have to match the regex above
|
||||||
|
sender="@as_main:test",
|
||||||
|
)
|
||||||
|
self.hs.get_datastores().main.services_cache = [appservice]
|
||||||
|
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
|
||||||
|
[appservice]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup a response.
|
||||||
|
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||||
|
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Claim OTKs, which will ask the appservice and do nothing else.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim_res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {
|
||||||
|
local_user: {device_id_1: {**as_otk, **as_fallback_key}}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now upload a fallback key.
|
||||||
|
res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||||
|
)
|
||||||
|
self.assertEqual(res, [])
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
local_user,
|
||||||
|
device_id_1,
|
||||||
|
{"fallback_keys": fallback_key},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# we should now have an unused alg1 key
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, ["alg1"])
|
||||||
|
|
||||||
|
# The appservice will return only the OTK.
|
||||||
|
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||||
|
({local_user: {device_id_1: as_otk}}, [])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Claim OTKs, which should return the OTK from the appservice and the
|
||||||
|
# uploaded fallback key.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim_res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {
|
||||||
|
local_user: {device_id_1: {**as_otk, **fallback_key}}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# But the fallback key should not be marked as used.
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, ["alg1"])
|
||||||
|
|
||||||
|
# Now upload a OTK.
|
||||||
|
self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
local_user,
|
||||||
|
device_id_1,
|
||||||
|
{"one_time_keys": otk},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Claim OTKs, which will return information only from the database.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim_res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# But the fallback key should not be marked as used.
|
||||||
|
fallback_res = self.get_success(
|
||||||
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
|
||||||
|
)
|
||||||
|
self.assertEqual(fallback_res, ["alg1"])
|
||||||
|
|
||||||
|
# Finally, return only the fallback key from the appservice.
|
||||||
|
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||||
|
({local_user: {device_id_1: as_fallback_key}}, [])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Claim OTKs, which will return only the fallback key from the database.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
claim_res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {local_user: {device_id_1: as_fallback_key}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
|
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
|
||||||
def test_query_local_devices_appservice(self) -> None:
|
def test_query_local_devices_appservice(self) -> None:
|
||||||
"""Test that querying of appservices for keys overrides responses from the database."""
|
"""Test that querying of appservices for keys overrides responses from the database."""
|
||||||
|
|
Loading…
Reference in a new issue