mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-16 17:03:52 +01:00
Issue one time keys in upload order (#17903)
Currently, one-time-keys are issued in a somewhat random order. (In practice, they are issued according to the lexicographical order of their key IDs.) That can lead to a situation where a client gives up hope of a given OTK ever being used, whilst it is still on the server. Related: https://github.com/element-hq/element-meta/issues/2356
This commit is contained in:
parent
eda735e4bb
commit
2a321bac35
5 changed files with 116 additions and 8 deletions
1
changelog.d/17903.bugfix
Normal file
1
changelog.d/17903.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures.
|
|
@ -615,7 +615,7 @@ class E2eKeysHandler:
|
|||
3. Attempt to fetch fallback keys from the database.
|
||||
|
||||
Args:
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
|
||||
always_include_fallback_keys: True to always include fallback keys.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
|
|||
unique=True,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
update_name="add_otk_ts_added_index",
|
||||
index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
|
||||
table="e2e_one_time_keys_json",
|
||||
columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
|
||||
)
|
||||
|
||||
|
||||
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
|
@ -1122,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
"""Take a list of one time keys out of the database.
|
||||
|
||||
Args:
|
||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
|
||||
|
||||
Returns:
|
||||
A tuple (results, missing) of:
|
||||
|
@ -1310,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
OTK was found.
|
||||
"""
|
||||
|
||||
# Return the oldest keys from this device (based on `ts_added_ms`).
|
||||
# Doing so means that keys are issued in the same order they were uploaded,
|
||||
# which reduces the chances of a client expiring its copy of a (private)
|
||||
# key while the public key is still on the server, waiting to be issued.
|
||||
sql = """
|
||||
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||
ORDER BY ts_added_ms
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
|
@ -1354,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
|
||||
for each OTK claimed.
|
||||
"""
|
||||
# Find, delete, and return the oldest keys from each device (based on
|
||||
# `ts_added_ms`).
|
||||
#
|
||||
# Doing so means that keys are issued in the same order they were uploaded,
|
||||
# which reduces the chances of a client expiring its copy of a (private)
|
||||
# key while the public key is still on the server, waiting to be issued.
|
||||
sql = """
|
||||
WITH claims(user_id, device_id, algorithm, claim_count) AS (
|
||||
VALUES ?
|
||||
), ranked_keys AS (
|
||||
SELECT
|
||||
user_id, device_id, algorithm, key_id, claim_count,
|
||||
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY (user_id, device_id, algorithm)
|
||||
ORDER BY ts_added_ms
|
||||
) AS r
|
||||
FROM e2e_one_time_keys_json
|
||||
JOIN claims USING (user_id, device_id, algorithm)
|
||||
)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
|
||||
-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
|
||||
-- efficiently be issued in the same order they were uploaded.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(8803, 'add_otk_ts_added_index', '{}');
|
|
@ -151,18 +151,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
def test_claim_one_time_key(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
keys = {"alg1:k1": "key1"}
|
||||
|
||||
res = self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(
|
||||
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
|
||||
)
|
||||
|
||||
res2 = self.get_success(
|
||||
# Keys should be returned in the order they were uploaded. To test, advance time
|
||||
# a little, then upload a second key with an earlier key ID; it should get
|
||||
# returned second.
|
||||
self.reactor.advance(1)
|
||||
res = self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(
|
||||
res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
|
||||
)
|
||||
|
||||
# now claim both keys back. They should be in the same order
|
||||
res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
self.requester,
|
||||
|
@ -171,12 +183,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
res2,
|
||||
res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
|
||||
},
|
||||
)
|
||||
res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
self.requester,
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
res,
|
||||
{
|
||||
"failures": {},
|
||||
"one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
|
||||
},
|
||||
)
|
||||
|
||||
def test_claim_one_time_key_bulk(self) -> None:
|
||||
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
|
||||
|
@ -336,6 +363,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
|
||||
)
|
||||
|
||||
def test_claim_one_time_key_bulk_ordering(self) -> None:
|
||||
"""Keys returned by the bulk claim call should be returned in the correct order"""
|
||||
|
||||
# Alice has lots of keys, uploaded in a specific order
|
||||
alice = f"@alice:{self.hs.hostname}"
|
||||
alice_dev = "alice_dev_1"
|
||||
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
alice,
|
||||
alice_dev,
|
||||
{"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
|
||||
)
|
||||
)
|
||||
# Advance time by 1s, to ensure that there is a difference in upload time.
|
||||
self.reactor.advance(1)
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
alice,
|
||||
alice_dev,
|
||||
{"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
|
||||
)
|
||||
)
|
||||
|
||||
# Now claim some, and check we get the right ones.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{alice: {alice_dev: {"alg1": 2}}},
|
||||
self.requester,
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
# We should get the first-uploaded keys, even though they have later key ids.
|
||||
# We should get a random set of two of k20, k21, k22.
|
||||
self.assertEqual(claim_res["failures"], {})
|
||||
claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
|
||||
self.assertEqual(len(claimed_keys), 2)
|
||||
for key_id in claimed_keys.keys():
|
||||
self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
|
||||
|
||||
def test_fallback_key(self) -> None:
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
|
|
Loading…
Reference in a new issue