mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 19:23:53 +01:00
Merge pull request #5727 from matrix-org/uhoreg/e2e_cross-signing2-part3
Cross-signing [4/4] -- federation edition
This commit is contained in:
commit
53d7680e32
7 changed files with 241 additions and 47 deletions
1
changelog.d/5727.feature
Normal file
1
changelog.d/5727.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add federation support for cross-signing.
|
|
@ -360,20 +360,20 @@ class PerDestinationQueue(object):
|
||||||
last_device_list = self._last_device_list_stream_id
|
last_device_list = self._last_device_list_stream_id
|
||||||
|
|
||||||
# Retrieve list of new device updates to send to the destination
|
# Retrieve list of new device updates to send to the destination
|
||||||
now_stream_id, results = yield self._store.get_devices_by_remote(
|
now_stream_id, results = yield self._store.get_device_updates_by_remote(
|
||||||
self._destination, last_device_list, limit=limit
|
self._destination, last_device_list, limit=limit
|
||||||
)
|
)
|
||||||
edus = [
|
edus = [
|
||||||
Edu(
|
Edu(
|
||||||
origin=self._server_name,
|
origin=self._server_name,
|
||||||
destination=self._destination,
|
destination=self._destination,
|
||||||
edu_type="m.device_list_update",
|
edu_type=edu_type,
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
for content in results
|
for (edu_type, content) in results
|
||||||
]
|
]
|
||||||
|
|
||||||
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
|
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
|
||||||
|
|
||||||
return (edus, now_stream_id)
|
return (edus, now_stream_id)
|
||||||
|
|
||||||
|
|
|
@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_federation_query_user_devices(self, user_id):
|
def on_federation_query_user_devices(self, user_id):
|
||||||
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
||||||
return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
|
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||||
|
self_signing_key = yield self.store.get_e2e_cross_signing_key(
|
||||||
|
user_id, "self_signing"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"devices": devices,
|
||||||
|
"master_key": master_key,
|
||||||
|
"self_signing_key": self_signing_key,
|
||||||
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_left_room(self, user, room_id):
|
def user_left_room(self, user, room_id):
|
||||||
|
|
|
@ -36,6 +36,8 @@ from synapse.types import (
|
||||||
get_verify_key_from_cross_signing_key,
|
get_verify_key_from_cross_signing_key,
|
||||||
)
|
)
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -49,10 +51,19 @@ class E2eKeysHandler(object):
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine = hs.is_mine
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self._edu_updater = SigningKeyEduUpdater(hs, self)
|
||||||
|
|
||||||
|
federation_registry = hs.get_federation_registry()
|
||||||
|
|
||||||
|
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||||
|
federation_registry.register_edu_handler(
|
||||||
|
"org.matrix.signing_key_update",
|
||||||
|
self._edu_updater.incoming_signing_key_update,
|
||||||
|
)
|
||||||
# doesn't really work as part of the generic query API, because the
|
# doesn't really work as part of the generic query API, because the
|
||||||
# query request requires an object POST, but we abuse the
|
# query request requires an object POST, but we abuse the
|
||||||
# "query handler" interface.
|
# "query handler" interface.
|
||||||
hs.get_federation_registry().register_query_handler(
|
federation_registry.register_query_handler(
|
||||||
"client_keys", self.on_federation_query_client_keys
|
"client_keys", self.on_federation_query_client_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -208,13 +219,15 @@ class E2eKeysHandler(object):
|
||||||
if user_id in destination_query:
|
if user_id in destination_query:
|
||||||
results[user_id] = keys
|
results[user_id] = keys
|
||||||
|
|
||||||
for user_id, key in remote_result["master_keys"].items():
|
if "master_keys" in remote_result:
|
||||||
if user_id in destination_query:
|
for user_id, key in remote_result["master_keys"].items():
|
||||||
cross_signing_keys["master_keys"][user_id] = key
|
if user_id in destination_query:
|
||||||
|
cross_signing_keys["master_keys"][user_id] = key
|
||||||
|
|
||||||
for user_id, key in remote_result["self_signing_keys"].items():
|
if "self_signing_keys" in remote_result:
|
||||||
if user_id in destination_query:
|
for user_id, key in remote_result["self_signing_keys"].items():
|
||||||
cross_signing_keys["self_signing_keys"][user_id] = key
|
if user_id in destination_query:
|
||||||
|
cross_signing_keys["self_signing_keys"][user_id] = key
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failure = _exception_to_failure(e)
|
failure = _exception_to_failure(e)
|
||||||
|
@ -252,7 +265,7 @@ class E2eKeysHandler(object):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[dict[str, dict[str, dict]]]: map from
|
defer.Deferred[dict[str, dict[str, dict]]]: map from
|
||||||
(master|self_signing|user_signing) -> user_id -> key
|
(master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
|
||||||
"""
|
"""
|
||||||
master_keys = {}
|
master_keys = {}
|
||||||
self_signing_keys = {}
|
self_signing_keys = {}
|
||||||
|
@ -344,7 +357,16 @@ class E2eKeysHandler(object):
|
||||||
"""
|
"""
|
||||||
device_keys_query = query_body.get("device_keys", {})
|
device_keys_query = query_body.get("device_keys", {})
|
||||||
res = yield self.query_local_devices(device_keys_query)
|
res = yield self.query_local_devices(device_keys_query)
|
||||||
return {"device_keys": res}
|
ret = {"device_keys": res}
|
||||||
|
|
||||||
|
# add in the cross-signing keys
|
||||||
|
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
|
||||||
|
device_keys_query, None
|
||||||
|
)
|
||||||
|
|
||||||
|
ret.update(cross_signing_keys)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -1058,3 +1080,100 @@ class SignatureListItem:
|
||||||
target_user_id = attr.ib()
|
target_user_id = attr.ib()
|
||||||
target_device_id = attr.ib()
|
target_device_id = attr.ib()
|
||||||
signature = attr.ib()
|
signature = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
|
class SigningKeyEduUpdater(object):
|
||||||
|
"""Handles incoming signing key updates from federation and updates the DB"""
|
||||||
|
|
||||||
|
def __init__(self, hs, e2e_keys_handler):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.federation = hs.get_federation_client()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.e2e_keys_handler = e2e_keys_handler
|
||||||
|
|
||||||
|
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||||
|
|
||||||
|
# user_id -> list of updates waiting to be handled.
|
||||||
|
self._pending_updates = {}
|
||||||
|
|
||||||
|
# Recently seen stream ids. We don't bother keeping these in the DB,
|
||||||
|
# but they're useful to have them about to reduce the number of spurious
|
||||||
|
# resyncs.
|
||||||
|
self._seen_updates = ExpiringCache(
|
||||||
|
cache_name="signing_key_update_edu",
|
||||||
|
clock=self.clock,
|
||||||
|
max_len=10000,
|
||||||
|
expiry_ms=30 * 60 * 1000,
|
||||||
|
iterable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def incoming_signing_key_update(self, origin, edu_content):
|
||||||
|
"""Called on incoming signing key update from federation. Responsible for
|
||||||
|
parsing the EDU and adding to pending updates list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin (string): the server that sent the EDU
|
||||||
|
edu_content (dict): the contents of the EDU
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = edu_content.pop("user_id")
|
||||||
|
master_key = edu_content.pop("master_key", None)
|
||||||
|
self_signing_key = edu_content.pop("self_signing_key", None)
|
||||||
|
|
||||||
|
if get_domain_from_id(user_id) != origin:
|
||||||
|
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
|
||||||
|
return
|
||||||
|
|
||||||
|
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||||
|
if not room_ids:
|
||||||
|
# We don't share any rooms with this user. Ignore update, as we
|
||||||
|
# probably won't get any further updates.
|
||||||
|
return
|
||||||
|
|
||||||
|
self._pending_updates.setdefault(user_id, []).append(
|
||||||
|
(master_key, self_signing_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self._handle_signing_key_updates(user_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_signing_key_updates(self, user_id):
|
||||||
|
"""Actually handle pending updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (string): the user whose updates we are processing
|
||||||
|
"""
|
||||||
|
|
||||||
|
device_handler = self.e2e_keys_handler.device_handler
|
||||||
|
|
||||||
|
with (yield self._remote_edu_linearizer.queue(user_id)):
|
||||||
|
pending_updates = self._pending_updates.pop(user_id, [])
|
||||||
|
if not pending_updates:
|
||||||
|
# This can happen since we batch updates
|
||||||
|
return
|
||||||
|
|
||||||
|
device_ids = []
|
||||||
|
|
||||||
|
logger.info("pending updates: %r", pending_updates)
|
||||||
|
|
||||||
|
for master_key, self_signing_key in pending_updates:
|
||||||
|
if master_key:
|
||||||
|
yield self.store.set_e2e_cross_signing_key(
|
||||||
|
user_id, "master", master_key
|
||||||
|
)
|
||||||
|
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
|
||||||
|
# verify_key is a VerifyKey from signedjson, which uses
|
||||||
|
# .version to denote the portion of the key ID after the
|
||||||
|
# algorithm and colon, which is the device ID
|
||||||
|
device_ids.append(verify_key.version)
|
||||||
|
if self_signing_key:
|
||||||
|
yield self.store.set_e2e_cross_signing_key(
|
||||||
|
user_id, "self_signing", self_signing_key
|
||||||
|
)
|
||||||
|
_, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
|
self_signing_key
|
||||||
|
)
|
||||||
|
device_ids.append(verify_key.version)
|
||||||
|
|
||||||
|
yield device_handler.notify_device_update(user_id, device_ids)
|
||||||
|
|
|
@ -37,6 +37,7 @@ from synapse.storage._base import (
|
||||||
make_in_list_sql_clause,
|
make_in_list_sql_clause,
|
||||||
)
|
)
|
||||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
|
from synapse.types import get_verify_key_from_cross_signing_key
|
||||||
from synapse.util import batch_iter
|
from synapse.util import batch_iter
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||||
|
|
||||||
|
@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_devices_by_remote(self, destination, from_stream_id, limit):
|
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
|
||||||
"""Get stream of updates to send to remote servers
|
"""Get a stream of device updates to send to the given remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): The host the device updates are intended for
|
||||||
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||||
|
limit (int): Maximum number of device updates to return
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[tuple[int, list[dict]]]:
|
Deferred[tuple[int, list[tuple[string,dict]]]]:
|
||||||
current stream id (ie, the stream id of the last update included in the
|
current stream id (ie, the stream id of the last update included in the
|
||||||
response), and the list of updates
|
response), and the list of updates, where each update is a pair of EDU
|
||||||
|
type and EDU contents
|
||||||
"""
|
"""
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
# stream_id; the rationale being that such a large device list update
|
# stream_id; the rationale being that such a large device list update
|
||||||
# is likely an error.
|
# is likely an error.
|
||||||
updates = yield self.runInteraction(
|
updates = yield self.runInteraction(
|
||||||
"get_devices_by_remote",
|
"get_device_updates_by_remote",
|
||||||
self._get_devices_by_remote_txn,
|
self._get_device_updates_by_remote_txn,
|
||||||
destination,
|
destination,
|
||||||
from_stream_id,
|
from_stream_id,
|
||||||
now_stream_id,
|
now_stream_id,
|
||||||
|
@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
if not updates:
|
if not updates:
|
||||||
return now_stream_id, []
|
return now_stream_id, []
|
||||||
|
|
||||||
|
# get the cross-signing keys of the users in the list, so that we can
|
||||||
|
# determine which of the device changes were cross-signing keys
|
||||||
|
users = set(r[0] for r in updates)
|
||||||
|
master_key_by_user = {}
|
||||||
|
self_signing_key_by_user = {}
|
||||||
|
for user in users:
|
||||||
|
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
|
||||||
|
if cross_signing_key:
|
||||||
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
|
cross_signing_key
|
||||||
|
)
|
||||||
|
# verify_key is a VerifyKey from signedjson, which uses
|
||||||
|
# .version to denote the portion of the key ID after the
|
||||||
|
# algorithm and colon, which is the device ID
|
||||||
|
master_key_by_user[user] = {
|
||||||
|
"key_info": cross_signing_key,
|
||||||
|
"device_id": verify_key.version,
|
||||||
|
}
|
||||||
|
|
||||||
|
cross_signing_key = yield self.get_e2e_cross_signing_key(
|
||||||
|
user, "self_signing"
|
||||||
|
)
|
||||||
|
if cross_signing_key:
|
||||||
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
|
cross_signing_key
|
||||||
|
)
|
||||||
|
self_signing_key_by_user[user] = {
|
||||||
|
"key_info": cross_signing_key,
|
||||||
|
"device_id": verify_key.version,
|
||||||
|
}
|
||||||
|
|
||||||
# if we have exceeded the limit, we need to exclude any results with the
|
# if we have exceeded the limit, we need to exclude any results with the
|
||||||
# same stream_id as the last row.
|
# same stream_id as the last row.
|
||||||
if len(updates) > limit:
|
if len(updates) > limit:
|
||||||
|
@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
# context which created the Edu.
|
# context which created the Edu.
|
||||||
|
|
||||||
query_map = {}
|
query_map = {}
|
||||||
for update in updates:
|
cross_signing_keys_by_user = {}
|
||||||
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
|
for user_id, device_id, update_stream_id, update_context in updates:
|
||||||
|
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
|
||||||
# Stop processing updates
|
# Stop processing updates
|
||||||
break
|
break
|
||||||
|
|
||||||
key = (update[0], update[1])
|
if (
|
||||||
|
user_id in master_key_by_user
|
||||||
|
and device_id == master_key_by_user[user_id]["device_id"]
|
||||||
|
):
|
||||||
|
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||||
|
result["master_key"] = master_key_by_user[user_id]["key_info"]
|
||||||
|
elif (
|
||||||
|
user_id in self_signing_key_by_user
|
||||||
|
and device_id == self_signing_key_by_user[user_id]["device_id"]
|
||||||
|
):
|
||||||
|
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
||||||
|
result["self_signing_key"] = self_signing_key_by_user[user_id][
|
||||||
|
"key_info"
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
key = (user_id, device_id)
|
||||||
|
|
||||||
update_context = update[3]
|
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
||||||
update_stream_id = update[2]
|
|
||||||
|
|
||||||
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
if update_stream_id > previous_update_stream_id:
|
||||||
|
query_map[key] = (update_stream_id, update_context)
|
||||||
if update_stream_id > previous_update_stream_id:
|
|
||||||
query_map[key] = (update_stream_id, update_context)
|
|
||||||
|
|
||||||
# If we didn't find any updates with a stream_id lower than the cutoff, it
|
# If we didn't find any updates with a stream_id lower than the cutoff, it
|
||||||
# means that there are more than limit updates all of which have the same
|
# means that there are more than limit updates all of which have the same
|
||||||
|
@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
# devices, in which case E2E isn't going to work well anyway. We'll just
|
# devices, in which case E2E isn't going to work well anyway. We'll just
|
||||||
# skip that stream_id and return an empty list, and continue with the next
|
# skip that stream_id and return an empty list, and continue with the next
|
||||||
# stream_id next time.
|
# stream_id next time.
|
||||||
if not query_map:
|
if not query_map and not cross_signing_keys_by_user:
|
||||||
return stream_id_cutoff, []
|
return stream_id_cutoff, []
|
||||||
|
|
||||||
results = yield self._get_device_update_edus_by_remote(
|
results = yield self._get_device_update_edus_by_remote(
|
||||||
destination, from_stream_id, query_map
|
destination, from_stream_id, query_map
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add the updated cross-signing keys to the results list
|
||||||
|
for user_id, result in iteritems(cross_signing_keys_by_user):
|
||||||
|
result["user_id"] = user_id
|
||||||
|
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
|
||||||
|
results.append(("org.matrix.signing_key_update", result))
|
||||||
|
|
||||||
return now_stream_id, results
|
return now_stream_id, results
|
||||||
|
|
||||||
def _get_devices_by_remote_txn(
|
def _get_device_updates_by_remote_txn(
|
||||||
self, txn, destination, from_stream_id, now_stream_id, limit
|
self, txn, destination, from_stream_id, now_stream_id, limit
|
||||||
):
|
):
|
||||||
"""Return device update information for a given remote destination
|
"""Return device update information for a given remote destination
|
||||||
|
@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
List: List of device updates
|
List: List of device updates
|
||||||
"""
|
"""
|
||||||
|
# get the list of device updates that need to be sent
|
||||||
sql = """
|
sql = """
|
||||||
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
|
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
|
||||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||||
|
@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
List[Dict]: List of objects representing an device update EDU
|
List[Dict]: List of objects representing an device update EDU
|
||||||
|
|
||||||
"""
|
"""
|
||||||
devices = yield self.runInteraction(
|
devices = (
|
||||||
"_get_e2e_device_keys_txn",
|
yield self.runInteraction(
|
||||||
self._get_e2e_device_keys_txn,
|
"_get_e2e_device_keys_txn",
|
||||||
query_map.keys(),
|
self._get_e2e_device_keys_txn,
|
||||||
include_all_devices=True,
|
query_map.keys(),
|
||||||
include_deleted_devices=True,
|
include_all_devices=True,
|
||||||
|
include_deleted_devices=True,
|
||||||
|
)
|
||||||
|
if query_map
|
||||||
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
else:
|
else:
|
||||||
result["deleted"] = True
|
result["deleted"] = True
|
||||||
|
|
||||||
results.append(result)
|
results.append(("m.device_list_update", result))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
|
@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
"get_received_txn_response",
|
"get_received_txn_response",
|
||||||
"set_received_txn_response",
|
"set_received_txn_response",
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
"get_devices_by_remote",
|
"get_device_updates_by_remote",
|
||||||
# Bits that user_directory needs
|
# Bits that user_directory needs
|
||||||
"get_user_directory_stream_pos",
|
"get_user_directory_stream_pos",
|
||||||
"get_current_state_deltas",
|
"get_current_state_deltas",
|
||||||
|
@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
retry_timings_res
|
retry_timings_res
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_devices_by_remote.return_value = (0, [])
|
self.datastore.get_device_updates_by_remote.return_value = (0, [])
|
||||||
|
|
||||||
def get_received_txn_response(*args):
|
def get_received_txn_response(*args):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
|
@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_devices_by_remote(self):
|
def test_get_device_updates_by_remote(self):
|
||||||
device_ids = ["device_id1", "device_id2"]
|
device_ids = ["device_id1", "device_id2"]
|
||||||
|
|
||||||
# Add two device updates with a single stream_id
|
# Add two device updates with a single stream_id
|
||||||
|
@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all device updates ever meant for this remote
|
# Get all device updates ever meant for this remote
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"somehost", -1, limit=100
|
"somehost", -1, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
self._check_devices_in_updates(device_ids, device_updates)
|
self._check_devices_in_updates(device_ids, device_updates)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_devices_by_remote_limited(self):
|
def test_get_device_updates_by_remote_limited(self):
|
||||||
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
||||||
|
|
||||||
# first add one device
|
# first add one device
|
||||||
|
@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
#
|
#
|
||||||
|
|
||||||
# first we should get a single update
|
# first we should get a single update
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"someotherhost", -1, limit=100
|
"someotherhost", -1, limit=100
|
||||||
)
|
)
|
||||||
self._check_devices_in_updates(device_ids1, device_updates)
|
self._check_devices_in_updates(device_ids1, device_updates)
|
||||||
|
|
||||||
# Then we should get an empty list back as the 101 devices broke the limit
|
# Then we should get an empty list back as the 101 devices broke the limit
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"someotherhost", now_stream_id, limit=100
|
"someotherhost", now_stream_id, limit=100
|
||||||
)
|
)
|
||||||
self.assertEqual(len(device_updates), 0)
|
self.assertEqual(len(device_updates), 0)
|
||||||
|
|
||||||
# The 101 devices should've been cleared, so we should now just get one device
|
# The 101 devices should've been cleared, so we should now just get one device
|
||||||
# update
|
# update
|
||||||
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
|
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||||
"someotherhost", now_stream_id, limit=100
|
"someotherhost", now_stream_id, limit=100
|
||||||
)
|
)
|
||||||
self._check_devices_in_updates(device_ids3, device_updates)
|
self._check_devices_in_updates(device_ids3, device_updates)
|
||||||
|
@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
"""Check that an specific device ids exist in a list of device update EDUs"""
|
"""Check that an specific device ids exist in a list of device update EDUs"""
|
||||||
self.assertEqual(len(device_updates), len(expected_device_ids))
|
self.assertEqual(len(device_updates), len(expected_device_ids))
|
||||||
|
|
||||||
received_device_ids = {update["device_id"] for update in device_updates}
|
received_device_ids = {
|
||||||
|
update["device_id"] for edu_type, update in device_updates
|
||||||
|
}
|
||||||
self.assertEqual(received_device_ids, set(expected_device_ids))
|
self.assertEqual(received_device_ids, set(expected_device_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
Loading…
Reference in a new issue