Only send out device list updates for our own users (#12465)

Broke in #12365
This commit is contained in:
Erik Johnston 2022-04-14 13:05:31 +01:00 committed by GitHub
parent 535a689cfc
commit 0b014eb25e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 8 deletions

View file

@ -0,0 +1 @@
Enable processing of device list updates asynchronously.

View file

@ -649,9 +649,13 @@ class DeviceHandler(DeviceWorkerHandler):
return return
for user_id, device_id, room_id, stream_id, opentracing_context in rows: for user_id, device_id, room_id, stream_id, opentracing_context in rows:
joined_user_ids = await self.store.get_users_in_room(room_id) hosts = set()
hosts = {get_domain_from_id(u) for u in joined_user_ids}
hosts.discard(self.server_name) # Ignore any users that aren't ours
if self.hs.is_mine_id(user_id):
joined_user_ids = await self.store.get_users_in_room(room_id)
hosts = {get_domain_from_id(u) for u in joined_user_ids}
hosts.discard(self.server_name)
# Check if we've already sent this update to some hosts # Check if we've already sent this update to some hosts
if current_stream_id == stream_id: if current_stream_id == stream_id:

View file

@ -1703,7 +1703,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
next(stream_id_iterator), next(stream_id_iterator),
user_id, user_id,
device_id, device_id,
False, not self.hs.is_mine_id(
user_id
), # We only need to send out update for *our* users
now, now,
encoded_context if whitelisted_homeserver(destination) else "{}", encoded_context if whitelisted_homeserver(destination) else "{}",
) )

View file

@ -162,7 +162,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]), federation_transport_client=Mock(
spec=["send_transaction", "query_user_devices"]
),
) )
def default_config(self): def default_config(self):
@ -218,6 +220,45 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.assertEqual(len(self.edus), 1) self.assertEqual(len(self.edus), 1)
self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
def test_dont_send_device_updates_for_remote_users(self):
"""Check that we don't send device updates for remote users"""
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
defer.succeed(
{
"stream_id": "1",
"user_id": "@user2:host2",
"devices": [{"device_id": "D1"}],
}
)
)
self.get_success(
self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
"host2",
{
"user_id": "@user2:host2",
"device_id": "D1",
"stream_id": "1",
"prev_ids": [],
},
)
)
self.reactor.advance(1)
# We shouldn't see an EDU for that update
self.assertEqual(self.edus, [])
# Check that we did successfully process the inbound EDU (otherwise this
# test would pass if we failed to process the EDU)
devices = self.get_success(
self.hs.get_datastores().main.get_cached_devices_for_user("@user2:host2")
)
self.assertIn("D1", devices)
def test_upload_signatures(self): def test_upload_signatures(self):
"""Uploading signatures on some devices should produce updates for that user""" """Uploading signatures on some devices should produce updates for that user"""

View file

@ -118,7 +118,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
device_ids = ["device_id1", "device_id2"] device_ids = ["device_id1", "device_id2"]
# Add two device updates with sequential `stream_id`s # Add two device updates with sequential `stream_id`s
self.add_device_change("user_id", device_ids, "somehost") self.add_device_change("@user_id:test", device_ids, "somehost")
# Get all device updates ever meant for this remote # Get all device updates ever meant for this remote
now_stream_id, device_updates = self.get_success( now_stream_id, device_updates = self.get_success(
@ -142,7 +142,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
"device_id4", "device_id4",
"device_id5", "device_id5",
] ]
self.add_device_change("user_id", device_ids, "somehost") self.add_device_change("@user_id:test", device_ids, "somehost")
# Get device updates meant for this remote # Get device updates meant for this remote
next_stream_id, device_updates = self.get_success( next_stream_id, device_updates = self.get_success(
@ -162,7 +162,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add some more device updates to ensure it still resumes properly # Add some more device updates to ensure it still resumes properly
device_ids = ["device_id6", "device_id7"] device_ids = ["device_id6", "device_id7"]
self.add_device_change("user_id", device_ids, "somehost") self.add_device_change("@user_id:test", device_ids, "somehost")
# Get the next batch of device updates # Get the next batch of device updates
next_stream_id, device_updates = self.get_success( next_stream_id, device_updates = self.get_success(