mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 09:53:47 +01:00
Make _get_e2e_device_keys_and_signatures_txn return an attrs (#8224)
this makes it a bit clearer what's going on.
This commit is contained in:
parent
b939251c37
commit
abeab964d5
3 changed files with 41 additions and 20 deletions
1
changelog.d/8224.misc
Normal file
1
changelog.d/8224.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
|
@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
prev_id = stream_id
|
||||
|
||||
if device is not None:
|
||||
key_json = device.get("key_json", None)
|
||||
key_json = device.key_json
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
|
||||
if "signatures" in device:
|
||||
for sig_user_id, sigs in device["signatures"].items():
|
||||
if device.signatures:
|
||||
for sig_user_id, sigs in device.signatures.items():
|
||||
result["keys"].setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
device_display_name = device.display_name
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
else:
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
import abc
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.enterprise.adbapi import Connection
|
||||
|
@ -33,6 +34,21 @@ if TYPE_CHECKING:
|
|||
from synapse.handlers.e2e_keys import SignatureListItem
|
||||
|
||||
|
||||
@attr.s
|
||||
class DeviceKeyLookupResult:
|
||||
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""
|
||||
|
||||
display_name = attr.ib(type=Optional[str])
|
||||
|
||||
# the key data from e2e_device_keys_json. Typically includes fields like
|
||||
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
|
||||
# key) and "signatures" (a signature of the structure by the ed25519 key)
|
||||
key_json = attr.ib(type=Optional[str])
|
||||
|
||||
# cross-signing sigs
|
||||
signatures = attr.ib(type=Optional[Dict], default=None)
|
||||
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
async def get_e2e_device_keys_for_federation_query(
|
||||
self, user_id: str
|
||||
|
@ -61,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
for device_id, device in user_devices.items():
|
||||
result = {"device_id": device_id}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
key_json = device.key_json
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
|
||||
if "signatures" in device:
|
||||
for sig_user_id, sigs in device["signatures"].items():
|
||||
if device.signatures:
|
||||
for sig_user_id, sigs in device.signatures.items():
|
||||
result["keys"].setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
device_display_name = device.display_name
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
|
@ -109,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
for user_id, device_keys in results.items():
|
||||
rv[user_id] = {}
|
||||
for device_id, device_info in device_keys.items():
|
||||
r = db_to_json(device_info.pop("key_json"))
|
||||
r = db_to_json(device_info.key_json)
|
||||
r["unsigned"] = {}
|
||||
display_name = device_info["device_display_name"]
|
||||
display_name = device_info.display_name
|
||||
if display_name is not None:
|
||||
r["unsigned"]["device_display_name"] = display_name
|
||||
if "signatures" in device_info:
|
||||
for sig_user_id, sigs in device_info["signatures"].items():
|
||||
if device_info.signatures:
|
||||
for sig_user_id, sigs in device_info.signatures.items():
|
||||
r.setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
@ -126,7 +142,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
@trace
|
||||
def _get_e2e_device_keys_and_signatures_txn(
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
) -> Dict[str, Dict[str, Optional[Dict]]]:
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
set_tag("include_all_devices", include_all_devices)
|
||||
set_tag("include_deleted_devices", include_deleted_devices)
|
||||
|
||||
|
@ -161,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
|
||||
sql = (
|
||||
"SELECT user_id, device_id, "
|
||||
" d.display_name AS device_display_name, "
|
||||
" d.display_name, "
|
||||
" k.key_json"
|
||||
" FROM devices d"
|
||||
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
|
||||
|
@ -172,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
result = {}
|
||||
for row in rows:
|
||||
result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
|
||||
for (user_id, device_id, display_name, key_json) in txn:
|
||||
if include_deleted_devices:
|
||||
deleted_devices.remove((row["user_id"], row["device_id"]))
|
||||
result.setdefault(row["user_id"], {})[row["device_id"]] = row
|
||||
deleted_devices.remove((user_id, device_id))
|
||||
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
|
||||
display_name, key_json
|
||||
)
|
||||
|
||||
if include_deleted_devices:
|
||||
for user_id, device_id in deleted_devices:
|
||||
|
@ -209,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
# note that target_device_result will be None for deleted devices.
|
||||
continue
|
||||
|
||||
target_device_signatures = target_device_result.setdefault("signatures", {})
|
||||
target_device_signatures = target_device_result.signatures
|
||||
if target_device_signatures is None:
|
||||
target_device_signatures = target_device_result.signatures = {}
|
||||
|
||||
signing_user_signatures = target_device_signatures.setdefault(
|
||||
signing_user_id, {}
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue