0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-13 11:23:25 +01:00

Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)

It can be useful to always return the fallback key when attempting to
claim keys. This adds an unstable endpoint for `/keys/claim` which
always returns fallback keys in addition to one-time-keys.

The fallback key(s) are not marked as "used" unless there are no
corresponding OTKs.

This is currently defined in MSC3983 (although likely to be split out
to a separate MSC). The endpoint shape may change or be requested
differently (i.e. a keyword parameter on the current endpoint), but the
core logic should be reasonable.
This commit is contained in:
Patrick Cloke 2023-04-25 13:30:41 -04:00 committed by GitHub
parent b39b02c26e
commit 8e9739449d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 371 additions and 29 deletions

1
changelog.d/15462.misc Normal file
View file

@ -0,0 +1 @@
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.

View file

@ -1005,7 +1005,7 @@ class FederationServer(FederationBase):
@trace
async def on_claim_client_keys(
self, origin: str, content: JsonDict
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
@ -1013,7 +1013,9 @@ class FederationServer(FederationBase):
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
)
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:

View file

@ -25,6 +25,7 @@ from synapse.federation.transport.server._base import (
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationUnstableClientKeysClaimServlet,
)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
@ -298,6 +299,11 @@ def register_servlets(
and not hs.config.experimental.msc3720_enabled
):
continue
if (
servletclass == FederationUnstableClientKeysClaimServlet
and not hs.config.experimental.msc3983_appservice_otk_claims
):
continue
servletclass(
hs=hs,

View file

@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(origin, content)
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False
)
return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
always includes fallback keys in the response.
"""
PREFIX = FEDERATION_UNSTABLE_PREFIX
PATH = "/user/keys/claim"
CATEGORY = "Federation requests"
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True
)
return 200, response

View file

@ -842,9 +842,7 @@ class ApplicationServicesHandler:
async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a
@ -856,7 +854,7 @@ class ApplicationServicesHandler:
Returns:
A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A map of user ID -> a map device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
@ -897,12 +895,11 @@ class ApplicationServicesHandler:
)
# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list
# and the caller combines them.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
# require exclusive control over the users, which is the outermost key).
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for success, result in results:
if success:
claimed_keys.append(result[0])
claimed_keys.update(result[0])
missing.extend(result[1])
return claimed_keys, missing

View file

@ -563,7 +563,9 @@ class E2eKeysHandler:
return ret
async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]]
self,
local_query: List[Tuple[str, str, str]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
@ -573,6 +575,7 @@ class E2eKeysHandler:
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
always_include_fallback_keys: True to always include fallback keys.
Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
@ -583,24 +586,73 @@ class E2eKeysHandler:
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
# TODO Should this query for fallback keys of uploaded OTKs if
# always_include_fallback_keys is True? The MSC is ambiguous.
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []
appservice_results = {}
# Calculate which user ID / device ID / algorithm tuples to get fallback
# keys for. This can be either only missing results *or* all results
# (which don't already have a fallback key).
if always_include_fallback_keys:
# Build the fallback query as any part of the original query where
# the appservice didn't respond with a fallback key.
fallback_query = []
# Iterate each item in the original query and search the results
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
for key_id, key_json in as_result.items():
if key_id.startswith(f"{algorithm}:"):
# A OTK or fallback key was found for this query.
found_otk = True
# A fallback key was found for this query, no need to
# query further.
if key_json.get("fallback", False):
break
else:
# No fallback key was found from appservices, query for it.
# Only mark the fallback key as used if no OTK was found
# (from either the database or appservices).
mark_as_used = not found_otk and not any(
key_id.startswith(f"{algorithm}:")
for key_id in otk_results.get(user_id, {})
.get(device_id, {})
.keys()
)
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
else:
# All fallback keys get marked as used.
fallback_query = [
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found
]
# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results)
return (otk_results, appservice_results, fallback_results)
@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
self,
query: Dict[str, Dict[str, Dict[str, str]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@ -617,7 +669,9 @@ class E2eKeysHandler:
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
results = await self.claim_local_one_time_keys(local_query)
results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
)
# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
@ -625,7 +679,9 @@ class E2eKeysHandler:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})
# Remote failures.
failures: Dict[str, JsonDict] = {}

View file

@ -15,6 +15,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Any, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError
@ -288,7 +289,33 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=False
)
return 200, result
class UnstableOneTimeKeyServlet(RestServlet):
"""
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
fallback keys in the response.
"""
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=True
)
return 200, result
@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
if hs.config.experimental.msc3983_appservice_otk_claims:
UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)

View file

@ -1149,18 +1149,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results, missing
async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str]]
self, query_list: Iterable[Tuple[str, str, str, bool]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Take a list of fallback keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
query_list: An iterable of tuples of
(user ID, device ID, algorithm, whether the key should be marked as used).
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm in query_list:
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
@ -1180,7 +1181,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
used = row["used"]
# Mark fallback key as used if not already.
if not used:
if not used and mark_as_used:
await self.db_pool.simple_update_one(
table="e2e_fallback_keys_json",
keyvalues={

View file

@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res2 = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# claiming an OTK again should return the same fallback key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
def test_fallback_key_always_returned(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
# Upload a OTK & fallback key.
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"one_time_keys": otk, "fallback_keys": fallback_key},
)
)
# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, ["alg1"])
# Claiming an OTK and requesting to always return the fallback key should
# return both.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
},
)
# This should not mark the key as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, ["alg1"])
# Claiming an OTK again should return only the fallback key.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# And mark it as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, [])
def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
},
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
@override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
def test_query_appservice_with_fallback(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id_1 = "xyz"
fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
otk = {"alg1:k2": {"desc": "key2"}}
as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
as_otk = {"alg1:k4": {"desc": "key4"}}
# Inject an appservice interested in this user.
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
self.hs.get_datastores().main.services_cache = [appservice]
self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
[appservice]
)
# Setup a response.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
)
# Claim OTKs, which will ask the appservice and do nothing else.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {
local_user: {device_id_1: {**as_otk, **as_fallback_key}}
},
},
)
# Now upload a fallback key.
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(res, [])
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id_1,
{"fallback_keys": fallback_key},
)
)
# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])
# The appservice will return only the OTK.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: as_otk}}, [])
)
# Claim OTKs, which should return the OTK from the appservice and the
# uploaded fallback key.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {
local_user: {device_id_1: {**as_otk, **fallback_key}}
},
},
)
# But the fallback key should not be marked as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])
# Now upload a OTK.
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id_1,
{"one_time_keys": otk},
)
)
# Claim OTKs, which will return information only from the database.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
},
)
# But the fallback key should not be marked as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
)
self.assertEqual(fallback_res, ["alg1"])
# Finally, return only the fallback key from the appservice.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: as_fallback_key}}, [])
)
# Claim OTKs, which will return only the fallback key from the database.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {local_user: {device_id_1: as_fallback_key}},
},
)
@override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
def test_query_local_devices_appservice(self) -> None:
"""Test that querying of appservices for keys overrides responses from the database."""