0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 15:01:23 +01:00

Improve robustness when handling a perspective key response by deduplicating received server keys. (#15423)

* Change `store_server_verify_keys` to take a `Mapping[(str, str), FKR]`

This is because we already can't handle duplicate keys — leads to cardinality violation

* Newsfile

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>

---------

Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
This commit is contained in:
reivilibre 2023-04-13 14:35:03 +00:00 committed by GitHub
parent 38272be037
commit edae20f926
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 43 additions and 28 deletions

1
changelog.d/15423.bugfix Normal file
View file

@ -0,0 +1 @@
Improve robustness when handling a perspective key response by deduplicating received server keys.

View file

@ -721,7 +721,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
) )
keys: Dict[str, Dict[str, FetchKeyResult]] = {} keys: Dict[str, Dict[str, FetchKeyResult]] = {}
added_keys: List[Tuple[str, str, FetchKeyResult]] = [] added_keys: Dict[Tuple[str, str], FetchKeyResult] = {}
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
@ -752,9 +752,27 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
# we continue to process the rest of the response # we continue to process the rest of the response
continue continue
added_keys.extend( for key_id, key in processed_response.items():
(server_name, key_id, key) for key_id, key in processed_response.items() dict_key = (server_name, key_id)
) if dict_key in added_keys:
already_present_key = added_keys[dict_key]
logger.warning(
"Duplicate server keys for %s (%s) from perspective %s (%r, %r)",
server_name,
key_id,
perspective_name,
already_present_key,
key,
)
if already_present_key.valid_until_ts > key.valid_until_ts:
# Favour the entry with the largest valid_until_ts,
# as `old_verify_keys` are also collected from this
# response.
continue
added_keys[dict_key] = key
keys.setdefault(server_name, {}).update(processed_response) keys.setdefault(server_name, {}).update(processed_response)
await self.store.store_server_verify_keys( await self.store.store_server_verify_keys(

View file

@ -15,7 +15,7 @@
import itertools import itertools
import logging import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -95,7 +95,7 @@ class KeyStore(SQLBaseStore):
self, self,
from_server: str, from_server: str,
ts_added_ms: int, ts_added_ms: int,
verify_keys: Iterable[Tuple[str, str, FetchKeyResult]], verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
) -> None: ) -> None:
"""Stores NACL verification keys for remote servers. """Stores NACL verification keys for remote servers.
Args: Args:
@ -108,7 +108,7 @@ class KeyStore(SQLBaseStore):
key_values = [] key_values = []
value_values = [] value_values = []
invalidations = [] invalidations = []
for server_name, key_id, fetch_result in verify_keys: for (server_name, key_id), fetch_result in verify_keys.items():
key_values.append((server_name, key_id)) key_values.append((server_name, key_id))
value_values.append( value_values.append(
( (

View file

@ -193,7 +193,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
r = self.hs.get_datastores().main.store_server_verify_keys( r = self.hs.get_datastores().main.store_server_verify_keys(
"server9", "server9",
int(time.time() * 1000), int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), 1000)},
) )
self.get_success(r) self.get_success(r)
@ -291,7 +291,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# None is not a valid value in FetchKeyResult, but we're abusing this # None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted # API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys. # to 0 when fetched in KeyStore.get_server_verify_keys.
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], # type: ignore[arg-type] {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
) )
self.get_success(r) self.get_success(r)

View file

@ -46,10 +46,10 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
store.store_server_verify_keys( store.store_server_verify_keys(
"from_server", "from_server",
10, 10,
[ {
("server1", key_id_1, FetchKeyResult(KEY_1, 100)), ("server1", key_id_1): FetchKeyResult(KEY_1, 100),
("server1", key_id_2, FetchKeyResult(KEY_2, 200)), ("server1", key_id_2): FetchKeyResult(KEY_2, 200),
], },
) )
) )
@ -90,10 +90,10 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
store.store_server_verify_keys( store.store_server_verify_keys(
"from_server", "from_server",
0, 0,
[ {
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), ("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), ("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
], },
) )
) )
@ -119,7 +119,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
signedjson.key.generate_signing_key("key2") signedjson.key.generate_signing_key("key2")
) )
d = store.store_server_verify_keys( d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))] "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
) )
self.get_success(d) self.get_success(d)

View file

@ -793,16 +793,12 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
hs.get_datastores().main.store_server_verify_keys( hs.get_datastores().main.store_server_verify_keys(
from_server=self.OTHER_SERVER_NAME, from_server=self.OTHER_SERVER_NAME,
ts_added_ms=clock.time_msec(), ts_added_ms=clock.time_msec(),
verify_keys=[ verify_keys={
( (self.OTHER_SERVER_NAME, verify_key_id): FetchKeyResult(
self.OTHER_SERVER_NAME, verify_key=verify_key,
verify_key_id, valid_until_ts=clock.time_msec() + 10000,
FetchKeyResult( ),
verify_key=verify_key, },
valid_until_ts=clock.time_msec() + 10000,
),
)
],
) )
) )