0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-08 21:58:53 +02:00

Add support for claiming multiple OTKs at once. (#15468)

MSC3983 provides a way to request multiple OTKs at once from appservices,
this extends this concept to the Client-Server API.

Note that this will likely be spit out into a separate MSC, but is currently part of
MSC3983.
This commit is contained in:
Patrick Cloke 2023-04-27 12:57:46 -04:00 committed by GitHub
parent 6efa674004
commit 57aeeb308b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 271 additions and 98 deletions

1
changelog.d/15468.misc Normal file
View file

@ -0,0 +1 @@
Support claiming more than one OTK at a time.

View file

@ -442,8 +442,10 @@ class ApplicationServiceApi(SimpleHttpClient):
return False
async def claim_client_keys(
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from an application service.
Note that any error (including a timeout) is treated as the application
@ -469,8 +471,10 @@ class ApplicationServiceApi(SimpleHttpClient):
# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
for user_id, device, algorithm in query:
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
for user_id, device, algorithm, count in query:
body.setdefault(user_id, {}).setdefault(device, []).extend(
[algorithm] * count
)
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try:
@ -493,11 +497,20 @@ class ApplicationServiceApi(SimpleHttpClient):
# or if some are still missing.
#
# TODO This places a lot of faith in the response shape being correct.
missing = [
(user_id, device, algorithm)
for user_id, device, algorithm in query
if algorithm not in response.get(user_id, {}).get(device, [])
]
missing = []
for user_id, device, algorithm, count in query:
# Count the number of keys in the response for this algorithm by
# checking which key IDs start with the algorithm. This uses that
# True == 1 in Python to generate a count.
response_count = sum(
key_id.startswith(f"{algorithm}:")
for key_id in response.get(user_id, {}).get(device, {})
)
count -= response_count
# If the appservice responds with fewer keys than requested, then
# consider the request unfulfilled.
if count > 0:
missing.append((user_id, device, algorithm, count))
return response, missing

View file

@ -235,7 +235,10 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: Optional[int]
self,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
@ -247,6 +250,50 @@ class FederationClient(FederationBase):
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
# Convert the query with counts into a stable and unstable query and check
# if attempting to claim more than 1 OTK.
content: Dict[str, Dict[str, str]] = {}
unstable_content: Dict[str, Dict[str, List[str]]] = {}
use_unstable = False
for user_id, one_time_keys in query.items():
for device_id, algorithms in one_time_keys.items():
if any(count > 1 for count in algorithms.values()):
use_unstable = True
if algorithms:
# For the stable query, choose only the first algorithm.
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
# For the unstable query, repeat each algorithm by count, then
# splat those into chain to get a flattened list of all algorithms.
#
# Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
unstable_content.setdefault(user_id, {})[device_id] = list(
itertools.chain(
*(
itertools.repeat(algorithm, count)
for algorithm, count in algorithms.items()
)
)
)
if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
destination, unstable_content, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't claim client keys with the unstable API, falling back to the v1 API"
)
else:
logger.debug("Skipping unstable claim client keys API")
return await self.transport_layer.claim_client_keys(
destination, content, timeout
)

View file

@ -1005,13 +1005,8 @@ class FederationServer(FederationBase):
@trace
async def on_claim_client_keys(
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys

View file

@ -650,10 +650,10 @@ class TransportLayerClient:
Response:
{
"device_keys": {
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
"<algorithm>:<key_id>": <OTK JSON>
}
}
}
@ -669,7 +669,50 @@ class TransportLayerClient:
path = _create_v1_path("/user/keys/claim")
return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
)
async def claim_client_keys_unstable(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {"<algorithm>": <count>}
}
}
}
Response:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": <OTK JSON>
}
}
}
}
Args:
destination: The server to query.
query_content: The user ids to query.
Returns:
A dict containing the one-time keys.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
return await self.client.post_json(
destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
)
async def get_missing_events(

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm, which is hard-coded to 1.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
key_query.append((user_id, device_id, algorithm, 1))
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False
key_query, always_include_fallback_keys=False
)
return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
always includes fallback keys in the response.
Identical to the stable endpoint (FederationClientKeysClaimServlet) except
it allows for querying for multiple OTKs at once and always includes fallback
keys in the response.
"""
PREFIX = FEDERATION_UNSTABLE_PREFIX
@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithms in device_keys.items():
counts = Counter(algorithms)
for algorithm, count in counts.items():
key_query.append((user_id, device_id, algorithm, count))
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True
key_query, always_include_fallback_keys=True
)
return 200, response
@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
FederationUnstableClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,

View file

@ -841,8 +841,10 @@ class ApplicationServicesHandler:
return True
async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
self, query: Iterable[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a
@ -863,18 +865,18 @@ class ApplicationServicesHandler:
services = self.store.get_app_services()
# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
missing = []
for user_id, device, algorithm in query:
for user_id, device, algorithm, count in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm))
missing.append((user_id, device, algorithm, count))
continue
# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm)
(user_id, device, algorithm, count)
)
continue

View file

@ -564,7 +564,7 @@ class E2eKeysHandler:
async def claim_local_one_time_keys(
self,
local_query: List[Tuple[str, str, str]],
local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
@ -581,6 +581,12 @@ class E2eKeysHandler:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""
# Cap the number of OTKs that can be claimed at once to avoid abuse.
local_query = [
(user_id, device_id, algorithm, min(count, 5))
for user_id, device_id, algorithm, count in local_query
]
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
# If the application services have not provided any keys via the C-S
@ -607,7 +613,7 @@ class E2eKeysHandler:
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm in local_query:
for user_id, device_id, algorithm, _count in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
@ -630,13 +636,17 @@ class E2eKeysHandler:
.get(device_id, {})
.keys()
)
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
else:
# All fallback keys get marked as used.
fallback_query = [
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found
for user_id, device_id, algorithm, count in not_found
]
# For each user that does not have a one-time keys available, see if
@ -650,18 +660,19 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
self,
query: Dict[str, Dict[str, Dict[str, str]]],
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
for user_id, one_time_keys in query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
for device_id, algorithms in one_time_keys.items():
for algorithm, count in algorithms.items():
local_query.append((user_id, device_id, algorithm, count))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@ -692,7 +703,7 @@ class E2eKeysHandler:
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout
destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:

View file

@ -16,7 +16,8 @@
import logging
import re
from typing import TYPE_CHECKING, Any, Optional, Tuple
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
# Generate a count for each algorithm, which is hard-coded to 1.
query: Dict[str, Dict[str, Dict[str, int]]] = {}
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
for device_id, algorithm in one_time_keys.items():
query.setdefault(user_id, {})[device_id] = {algorithm: 1}
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=False
query, timeout, always_include_fallback_keys=False
)
return 200, result
class UnstableOneTimeKeyServlet(RestServlet):
"""
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
fallback keys in the response.
Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
querying for multiple OTKs at once and always includes fallback keys in the
response.
POST /keys/claim HTTP/1.1
{
"one_time_keys": {
"<user_id>": {
"<device_id>": ["<algorithm>", ...]
} } }
HTTP/1.1 200 OK
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
"""
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
# Generate a count for each algorithm.
query: Dict[str, Dict[str, Dict[str, int]]] = {}
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
for device_id, algorithms in one_time_keys.items():
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=True
query, timeout, always_include_fallback_keys=True
)
return 200, result

View file

@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
self, query_list: Iterable[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Take a list of one time keys out of the database.
Args:
@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@trace
def _claim_e2e_one_time_key_simple(
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
Returns:
@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT 1
LIMIT ?
"""
txn.execute(sql, (user_id, device_id, algorithm))
otk_row = txn.fetchone()
if otk_row is None:
return None
txn.execute(sql, (user_id, device_id, algorithm, count))
otk_rows = list(txn)
if not otk_rows:
return []
key_id, key_json = otk_row
self.db_pool.simple_delete_one_txn(
self.db_pool.simple_delete_many_txn(
txn,
table="e2e_one_time_keys_json",
column="key_id",
values=[otk_row[0] for otk_row in otk_rows],
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return f"{algorithm}:{key_id}", key_json
return [
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]
@trace
def _claim_e2e_one_time_key_returning(
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
Returns:
@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT 1
LIMIT ?
)
RETURNING key_id, key_json
"""
txn.execute(
sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
sql,
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
)
otk_row = txn.fetchone()
if otk_row is None:
return None
otk_rows = list(txn)
if not otk_rows:
return []
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
return [
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
]
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str]] = []
for user_id, device_id, algorithm in query_list:
missing: List[Tuple[str, str, str, int]] = []
for user_id, device_id, algorithm, count in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode.
@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
claim_row = await self.db_pool.runInteraction(
claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
device_id,
algorithm,
count,
db_autocommit=db_autocommit,
)
if claim_row:
if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
else:
missing.append((user_id, device_id, algorithm))
for claim_row in claim_rows:
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
# Did we get enough OTKs?
count -= len(claim_rows)
if count:
missing.append((user_id, device_id, algorithm, count))
return results, missing

View file

@ -195,11 +195,11 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
MISSING_KEYS = [
# Known user, known device, missing algorithm.
("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"),
("@alice:example.org", "DEVICE_2", "xyz", 1),
# Known user, missing device.
("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"),
("@alice:example.org", "DEVICE_3", "signed_curve25519", 1),
# Unknown user.
("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"),
("@bob:example.org", "DEVICE_4", "signed_curve25519", 1),
]
claimed_keys, missing = self.get_success(
@ -207,9 +207,8 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
self.service,
[
# Found devices
("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"),
("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
("@alice:example.org", "DEVICE_1", "signed_curve25519", 1),
("@alice:example.org", "DEVICE_2", "signed_curve25519", 1),
]
+ MISSING_KEYS,
)

View file

@ -160,7 +160,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res2 = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=False,
)
@ -205,7 +205,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=False,
)
@ -224,7 +224,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# claiming an OTK again should return the same fallback key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=False,
)
@ -273,7 +273,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=False,
)
@ -285,7 +285,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=False,
)
@ -306,7 +306,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=False,
)
@ -347,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# return both.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=True,
)
@ -369,7 +369,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claiming an OTK again should return only the fallback key.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}},
{local_user: {device_id: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=True,
)
@ -1052,7 +1052,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Setup a response, but only for device 2.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")])
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
)
# we shouldn't have any unused fallback keys yet
@ -1079,11 +1079,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# query the fallback keys.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{
"one_time_keys": {
local_user: {device_id_1: "alg1", device_id_2: "alg1"}
}
},
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=False,
)
@ -1128,7 +1124,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claim OTKs, which will ask the appservice and do nothing else.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
{local_user: {device_id_1: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=True,
)
@ -1172,7 +1168,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# uploaded fallback key.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
{local_user: {device_id_1: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=True,
)
@ -1205,7 +1201,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claim OTKs, which will return information only from the database.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
{local_user: {device_id_1: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=True,
)
@ -1232,7 +1228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Claim OTKs, which will return only the fallback key from the database.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
{local_user: {device_id_1: {"alg1": 1}}},
timeout=None,
always_include_fallback_keys=True,
)