mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-15 22:42:23 +01:00
Split fetching device keys and signatures into two transactions (#8233)
I think this is simpler (and moves stuff out of the db threads)
This commit is contained in:
parent
208e1d3eb3
commit
f97f9485ee
2 changed files with 68 additions and 46 deletions
1
changelog.d/8233.misc
Normal file
1
changelog.d/8233.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Refactor queries for device keys and cross-signatures.
|
|
@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import make_in_list_sql_clause
|
from synapse.storage.database import make_in_list_sql_clause
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
|
||||||
# key) and "signatures" (a signature of the structure by the ed25519 key)
|
# key) and "signatures" (a signature of the structure by the ed25519 key)
|
||||||
key_json = attr.ib(type=Optional[str])
|
key_json = attr.ib(type=Optional[str])
|
||||||
|
|
||||||
# cross-signing sigs
|
# cross-signing sigs on this device.
|
||||||
signatures = attr.ib(type=Optional[Dict], default=None)
|
# dict from (signing user_id)->(signing device_id)->sig
|
||||||
|
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
|
@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
include_all_devices: bool = False,
|
include_all_devices: bool = False,
|
||||||
include_deleted_devices: bool = False,
|
include_deleted_devices: bool = False,
|
||||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||||
"""Fetch a list of device keys, together with their cross-signatures.
|
"""Fetch a list of device keys
|
||||||
|
|
||||||
|
Any cross-signatures made on the keys by the owner of the device are also
|
||||||
|
included.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_list: List of pairs of user_ids and device_ids. Device id can be None
|
query_list: List of pairs of user_ids and device_ids. Device id can be None
|
||||||
|
@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
result = await self.db_pool.runInteraction(
|
result = await self.db_pool.runInteraction(
|
||||||
"get_e2e_device_keys",
|
"get_e2e_device_keys",
|
||||||
self._get_e2e_device_keys_and_signatures_txn,
|
self._get_e2e_device_keys_txn,
|
||||||
query_list,
|
query_list,
|
||||||
include_all_devices,
|
include_all_devices,
|
||||||
include_deleted_devices,
|
include_deleted_devices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# get the (user_id, device_id) tuples to look up cross-signatures for
|
||||||
|
signature_query = (
|
||||||
|
(user_id, device_id)
|
||||||
|
for user_id, dev in result.items()
|
||||||
|
for device_id, d in dev.items()
|
||||||
|
if d is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch in batch_iter(signature_query, 50):
|
||||||
|
cross_sigs_result = await self.db_pool.runInteraction(
|
||||||
|
"get_e2e_cross_signing_signatures",
|
||||||
|
self._get_e2e_cross_signing_signatures_for_devices_txn,
|
||||||
|
batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# add each cross-signing signature to the correct device in the result dict.
|
||||||
|
for (user_id, key_id, device_id, signature) in cross_sigs_result:
|
||||||
|
target_device_result = result[user_id][device_id]
|
||||||
|
target_device_signatures = target_device_result.signatures
|
||||||
|
|
||||||
|
signing_user_signatures = target_device_signatures.setdefault(
|
||||||
|
user_id, {}
|
||||||
|
)
|
||||||
|
signing_user_signatures[key_id] = signature
|
||||||
|
|
||||||
log_kv(result)
|
log_kv(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_e2e_device_keys_and_signatures_txn(
|
def _get_e2e_device_keys_txn(
|
||||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||||
|
"""Get information on devices from the database
|
||||||
|
|
||||||
|
The results include the device's keys and self-signatures, but *not* any
|
||||||
|
cross-signing signatures which have been added subsequently (for which, see
|
||||||
|
get_e2e_device_keys_and_signatures)
|
||||||
|
"""
|
||||||
query_clauses = []
|
query_clauses = []
|
||||||
query_params = []
|
query_params = []
|
||||||
signature_query_clauses = []
|
|
||||||
signature_query_params = []
|
|
||||||
|
|
||||||
if include_all_devices is False:
|
if include_all_devices is False:
|
||||||
include_deleted_devices = False
|
include_deleted_devices = False
|
||||||
|
@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
for (user_id, device_id) in query_list:
|
for (user_id, device_id) in query_list:
|
||||||
query_clause = "user_id = ?"
|
query_clause = "user_id = ?"
|
||||||
query_params.append(user_id)
|
query_params.append(user_id)
|
||||||
signature_query_clause = "target_user_id = ?"
|
|
||||||
signature_query_params.append(user_id)
|
|
||||||
|
|
||||||
if device_id is not None:
|
if device_id is not None:
|
||||||
query_clause += " AND device_id = ?"
|
query_clause += " AND device_id = ?"
|
||||||
query_params.append(device_id)
|
query_params.append(device_id)
|
||||||
signature_query_clause += " AND target_device_id = ?"
|
|
||||||
signature_query_params.append(device_id)
|
|
||||||
|
|
||||||
signature_query_clause += " AND user_id = ?"
|
|
||||||
signature_query_params.append(user_id)
|
|
||||||
|
|
||||||
query_clauses.append(query_clause)
|
query_clauses.append(query_clause)
|
||||||
signature_query_clauses.append(signature_query_clause)
|
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT user_id, device_id, "
|
"SELECT user_id, device_id, "
|
||||||
|
@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
for user_id, device_id in deleted_devices:
|
for user_id, device_id in deleted_devices:
|
||||||
result.setdefault(user_id, {})[device_id] = None
|
result.setdefault(user_id, {})[device_id] = None
|
||||||
|
|
||||||
# get signatures on the device
|
return result
|
||||||
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
|
|
||||||
|
def _get_e2e_cross_signing_signatures_for_devices_txn(
|
||||||
|
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
|
||||||
|
) -> List[Tuple[str, str, str, str]]:
|
||||||
|
"""Get cross-signing signatures for a given list of devices
|
||||||
|
|
||||||
|
Returns signatures made by the owners of the devices.
|
||||||
|
|
||||||
|
Returns: a list of results; each entry in the list is a tuple of
|
||||||
|
(user_id, key_id, target_device_id, signature).
|
||||||
|
"""
|
||||||
|
signature_query_clauses = []
|
||||||
|
signature_query_params = []
|
||||||
|
|
||||||
|
for (user_id, device_id) in device_query:
|
||||||
|
signature_query_clauses.append(
|
||||||
|
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
|
||||||
|
)
|
||||||
|
signature_query_params.extend([user_id, device_id, user_id])
|
||||||
|
|
||||||
|
signature_sql = """
|
||||||
|
SELECT user_id, key_id, target_device_id, signature
|
||||||
|
FROM e2e_cross_signing_signatures WHERE %s
|
||||||
|
""" % (
|
||||||
" OR ".join("(" + q + ")" for q in signature_query_clauses)
|
" OR ".join("(" + q + ")" for q in signature_query_clauses)
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(signature_sql, signature_query_params)
|
txn.execute(signature_sql, signature_query_params)
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
return txn.fetchall()
|
||||||
|
|
||||||
# add each cross-signing signature to the correct device in the result dict.
|
|
||||||
for row in rows:
|
|
||||||
signing_user_id = row["user_id"]
|
|
||||||
signing_key_id = row["key_id"]
|
|
||||||
target_user_id = row["target_user_id"]
|
|
||||||
target_device_id = row["target_device_id"]
|
|
||||||
signature = row["signature"]
|
|
||||||
|
|
||||||
target_user_result = result.get(target_user_id)
|
|
||||||
if not target_user_result:
|
|
||||||
continue
|
|
||||||
|
|
||||||
target_device_result = target_user_result.get(target_device_id)
|
|
||||||
if not target_device_result:
|
|
||||||
# note that target_device_result will be None for deleted devices.
|
|
||||||
continue
|
|
||||||
|
|
||||||
target_device_signatures = target_device_result.signatures
|
|
||||||
if target_device_signatures is None:
|
|
||||||
target_device_signatures = target_device_result.signatures = {}
|
|
||||||
|
|
||||||
signing_user_signatures = target_device_signatures.setdefault(
|
|
||||||
signing_user_id, {}
|
|
||||||
)
|
|
||||||
signing_user_signatures[signing_key_id] = signature
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def get_e2e_one_time_keys(
|
async def get_e2e_one_time_keys(
|
||||||
self, user_id: str, device_id: str, key_ids: List[str]
|
self, user_id: str, device_id: str, key_ids: List[str]
|
||||||
|
|
Loading…
Reference in a new issue