Handle local device list updates during partial join (#13934)

This commit is contained in:
Erik Johnston 2022-09-28 23:22:35 +01:00 committed by GitHub
parent df8b91ed2b
commit 5f659d4a88
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 141 additions and 15 deletions

1
changelog.d/13934.misc Normal file
View file

@ -0,0 +1 @@
Correctly handle sending local device list updates to remote servers during a partial join.

View file

@ -762,10 +762,90 @@ class DeviceHandler(DeviceWorkerHandler):
gone from partial to full state. gone from partial to full state.
""" """
# We defer to the device list updater implementation as we're on the # We defer to the device list updater to handle pending remote device
# right worker. # list updates.
await self.device_list_updater.handle_room_un_partial_stated(room_id) await self.device_list_updater.handle_room_un_partial_stated(room_id)
# Replay local updates.
(
join_event_id,
device_lists_stream_id,
) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
room_id
)
# Get the local device list changes that have happened in the room since
# we started joining. If there are no updates there's nothing left to do.
changes = await self.store.get_device_list_changes_in_room(
room_id, device_lists_stream_id
)
local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
if not local_changes:
return
# Note: We have persisted the full state at this point, we just haven't
# cleared the `partial_room` flag.
join_state_ids = await self._state_storage.get_state_ids_for_event(
join_event_id, await_full_state=False
)
current_state_ids = await self.store.get_partial_current_state_ids(room_id)
# Now we need to work out all servers that might have been in the room
# at any point during our join.
# First we look for any membership states that have changed between the
# initial join and now...
all_keys = set(join_state_ids)
all_keys.update(current_state_ids)
potentially_changed_hosts = set()
for etype, state_key in all_keys:
if etype != EventTypes.Member:
continue
prev = join_state_ids.get((etype, state_key))
current = current_state_ids.get((etype, state_key))
if prev != current:
potentially_changed_hosts.add(get_domain_from_id(state_key))
# ... then we add all the hosts that are currently joined to the room...
current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
potentially_changed_hosts.update(current_hosts_in_room)
# ... and finally we remove any hosts that we were told about, as we
# will have sent device list updates to those hosts when they happened.
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
room_id
)
potentially_changed_hosts.difference_update(known_hosts_at_join)
potentially_changed_hosts.discard(self.server_name)
if not potentially_changed_hosts:
# Nothing to do.
return
logger.info(
"Found %d changed hosts to send device list updates to",
len(potentially_changed_hosts),
)
for user_id, device_id in local_changes:
await self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
stream_id=None,
hosts=potentially_changed_hosts,
context=None,
)
# Notify things that device lists need to be sent out.
self.notifier.notify_replication()
for host in potentially_changed_hosts:
self.federation_sender.send_device_messages(host, immediate=False)
def _update_device_from_client_ips( def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]

View file

@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes return changes
async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
"""Get all device list changes that happened in the room since the given
stream ID.
Returns:
Collection of user ID/device ID tuples of all devices that have
changed
"""
sql = """
SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
WHERE room_id = ? AND stream_id > ?
"""
def get_device_list_changes_in_room_txn(
txn: LoggingTransaction,
) -> Collection[Tuple[str, str]]:
txn.execute(sql, (room_id, min_stream_id))
return cast(Collection[Tuple[str, str]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_device_list_changes_in_room",
get_device_list_changes_in_room_txn,
)
class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__( def __init__(
@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str, user_id: str,
device_id: str, device_id: str,
room_id: str, room_id: str,
stream_id: int, stream_id: Optional[int],
hosts: Collection[str], hosts: Collection[str],
context: Optional[Dict[str, str]], context: Optional[Dict[str, str]],
) -> None: ) -> None:
"""Queue the device update to be sent to the given set of hosts, """Queue the device update to be sent to the given set of hosts,
calculated from the room ID. calculated from the room ID.
Marks the associated row in `device_lists_changes_in_room` as handled. Marks the associated row in `device_lists_changes_in_room` as handled,
if `stream_id` is provided.
""" """
def add_device_list_outbound_pokes_txn( def add_device_list_outbound_pokes_txn(
@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context=context, context=context,
) )
self.db_pool.simple_update_txn( if stream_id:
txn, self.db_pool.simple_update_txn(
table="device_lists_changes_in_room", txn,
keyvalues={ table="device_lists_changes_in_room",
"user_id": user_id, keyvalues={
"device_id": device_id, "user_id": user_id,
"stream_id": stream_id, "device_id": device_id,
"room_id": room_id, "stream_id": stream_id,
}, "room_id": room_id,
updatevalues={"converted_to_destinations": True}, },
) updatevalues={"converted_to_destinations": True},
)
if not hosts: if not hosts:
# If there are no hosts then we don't try and generate stream IDs. # If there are no hosts then we don't try and generate stream IDs.

View file

@ -1256,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return entry is not None return entry is not None
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
"""Get the event ID of the initial join that started the partial
join, and the device list stream ID at the point we started the partial
join.
"""
result = await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
)
return result["join_event_id"], result["device_lists_stream_id"]
class _BackgroundUpdates: class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"