From f162c92f2a1f8cf41b5f7211cb31cf8ae49b20e7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 16 Aug 2024 16:04:02 +0100 Subject: [PATCH] Speed up `/keys/changes` (#17548) Follow on from #17537. This is just adding a batched lookup function (you might want to hide whitespace in the diff). --- changelog.d/17548.misc | 1 + synapse/handlers/device.py | 32 ++++------ .../storage/databases/main/state_deltas.py | 64 ++++++++++++++++++- 3 files changed, 77 insertions(+), 20 deletions(-) create mode 100644 changelog.d/17548.misc diff --git a/changelog.d/17548.misc b/changelog.d/17548.misc new file mode 100644 index 000000000..861b241dc --- /dev/null +++ b/changelog.d/17548.misc @@ -0,0 +1 @@ +Fix performance of device lists in `/key/changes` and sliding sync. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ce26c91a7..4f2a9f3a5 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -267,31 +267,27 @@ class DeviceWorkerHandler: newly_left_rooms.add(change.room_id) # We now work out if any other users have since joined or left the rooms - # the user is currently in. First we filter out rooms that we know - # haven't changed recently. - rooms_changed = self.store.get_rooms_that_changed( - joined_room_ids, from_token.room_key - ) + # the user is currently in. # List of membership changes per room room_to_deltas: Dict[str, List[StateDelta]] = {} # The set of event IDs of membership events (so we can fetch their # associated membership). memberships_to_fetch: Set[str] = set() - for room_id in rooms_changed: - # TODO: Only pull out membership events? - state_changes = await self.store.get_current_state_deltas_for_room( - room_id, from_token=from_token.room_key, to_token=now_token.room_key - ) - for delta in state_changes: - if delta.event_type != EventTypes.Member: - continue - room_to_deltas.setdefault(room_id, []).append(delta) - if delta.event_id: - memberships_to_fetch.add(delta.event_id) - if delta.prev_event_id: - memberships_to_fetch.add(delta.prev_event_id) + # TODO: Only pull out membership events? + state_changes = await self.store.get_current_state_deltas_for_rooms( + joined_room_ids, from_token=from_token.room_key, to_token=now_token.room_key + ) + for delta in state_changes: + if delta.event_type != EventTypes.Member: + continue + + room_to_deltas.setdefault(delta.room_id, []).append(delta) + if delta.event_id: + memberships_to_fetch.add(delta.event_id) + if delta.prev_event_id: + memberships_to_fetch.add(delta.prev_event_id) # Fetch all the memberships for the membership events event_id_to_memberships = await self.store.get_membership_from_event_ids( diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 7d491d172..eaa13da36 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -26,10 +26,11 @@ import attr from synapse.logging.opentracing import trace from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import _filter_results_by_stream -from synapse.types import RoomStreamToken +from synapse.types import RoomStreamToken, StrCollection from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) @@ -200,3 +201,62 @@ class StateDeltasStore(SQLBaseStore): return await self.db_pool.runInteraction( "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn ) + + @trace + async def get_current_state_deltas_for_rooms( + self, + room_ids: StrCollection, + from_token: RoomStreamToken, + to_token: RoomStreamToken, + ) -> List[StateDelta]: + """Get the state deltas between two tokens for the set of rooms.""" + + room_ids = self._curr_state_delta_stream_cache.get_entities_changed( + room_ids, from_token.stream + ) + if not room_ids: + return [] + + def get_current_state_deltas_for_rooms_txn( + txn: LoggingTransaction, + room_ids: StrCollection, + ) -> List[StateDelta]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE {clause} AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + args.append(from_token.stream) + args.append(to_token.get_max_stream_pos()) + + txn.execute(sql, args) + + return [ + StateDelta( + stream_id=row[1], + room_id=row[2], + event_type=row[3], + state_key=row[4], + event_id=row[5], + prev_event_id=row[6], + ) + for row in txn + if _filter_results_by_stream(from_token, to_token, row[0], row[1]) + ] + + results = [] + for batch in batch_iter(room_ids, 1000): + deltas = await self.db_pool.runInteraction( + "get_current_state_deltas_for_rooms", + get_current_state_deltas_for_rooms_txn, + batch, + ) + + results.extend(deltas) + + return results