0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-16 09:38:19 +02:00

Add cache to get_server_keys_json_for_remote (#16123)

This commit is contained in:
Erik Johnston 2023-08-18 11:05:01 +01:00 committed by GitHub
parent 54a51ff6c1
commit 0aba4a4eaa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 144 additions and 101 deletions

1
changelog.d/16123.misc Normal file
View file

@ -0,0 +1 @@
Add cache to `get_server_keys_json_for_remote`.

View file

@ -14,7 +14,7 @@
import logging
import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from signedjson.sign import sign_json
@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
)
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict:
logger.info("Handling query for keys %r", query)
store_queries = []
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
if key_ids:
results: Mapping[
str, Optional[FetchKeyResultForRemote]
] = await self.store.get_server_keys_json_for_remote(
server_name, key_ids
)
else:
results = await self.store.get_all_server_keys_json_for_remote(
server_name
)
cached = await self.store.get_server_keys_json_for_remote(store_queries)
server_keys.update(
((server_name, key_id), res) for key_id, res in results.items()
)
json_results: Set[bytes] = set()
@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]
if key_id is None:
for (server_name, key_id), key_result in server_keys.items():
if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
for _, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
if key_result:
json_results.add(key_result.key_json)
continue
miss = False
if not results:
if key_result is None:
miss = True
else:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
ts_added_ms = key_result.added_ts
ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms,
time_now_ms,
)
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
json_results.add(key_result.key_json)
if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist

View file

@ -16,14 +16,13 @@
import itertools
import json
import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.keys import FetchKeyResult
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview
class KeyStore(SQLBaseStore):
class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate(((server_name, key_id),))
await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
)
@cached()
def _get_server_keys_json(
@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
@cached()
def get_server_key_json_for_remote(
self,
server_name: str,
key_id: str,
) -> Optional[FetchKeyResultForRemote]:
raise NotImplementedError()
@cachedList(
cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
)
async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
self, server_name: str, key_ids: Iterable[str]
) -> Dict[str, Optional[FetchKeyResultForRemote]]:
"""Fetch the cached keys for the given server/key IDs.
Args:
server_keys: List of (server_name, key_id, source) triplets.
Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
If we have multiple entries for a given key ID, returns the most recent.
"""
def _get_server_keys_json_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
results = {}
for server_name, key_id, from_server in server_keys:
keyvalues = {"server_name": server_name}
if key_id is not None:
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
rows = self.db_pool.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
)
results[(server_name, key_id, from_server)] = rows
return results
return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
rows = await self.db_pool.simple_select_many_batch(
table="server_keys_json",
column="key_id",
iterable=key_ids,
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
# We sort the rows so that the most recently added entry is picked up.
rows.sort(key=lambda r: r["ts_added_ms"])
return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}
async def get_all_server_keys_json_for_remote(
self,
server_name: str,
) -> Dict[str, FetchKeyResultForRemote]:
"""Fetch the cached keys for the given server.
If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
rows.sort(key=lambda r: r["ts_added_ms"])
return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}

View file

@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
class FetchKeyResult:
verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
key_json: bytes # the full key JSON
valid_until_ts: int # how long we can use this key for, in milliseconds.
added_ts: int # When we added this key, in milliseconds.

View file

@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""