diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c478e0bc5c..e28c74daf0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -16,6 +16,7 @@
 """Contains handlers for federation events."""

 import logging
 from http import HTTPStatus
+from queue import Empty, PriorityQueue
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

@@ -1041,6 +1042,135 @@ class FederationHandler:
         else:
             return []

+    async def get_backfill_events(
+        self, room_id: str, event_id_list: List[str], limit: int
+    ) -> List[EventBase]:
+        event_id_results = set()
+
+        # In a PriorityQueue, the lowest valued entries are retrieved first.
+        # We're using depth as the priority in the queue, tie-broken on
+        # stream_ordering. Depth is lowest at the oldest-in-time message and
+        # highest at the newest-in-time message. We add events to the queue with
+        # a negative depth so that we process the newest-in-time messages first,
+        # going backwards in time. stream_ordering follows the same pattern.
+        queue = PriorityQueue()
+
+        seed_events = await self.store.get_events_as_list(event_id_list)
+        for seed_event in seed_events:
+            # Make sure the seed event actually pertains to this room. We also
+            # need to make sure the depth is available since our whole DAG
+            # navigation here depends on depth.
+            if seed_event.room_id == room_id and seed_event.depth:
+                queue.put(
+                    (
+                        -seed_event.depth,
+                        -seed_event.internal_metadata.stream_ordering,
+                        seed_event.event_id,
+                        seed_event.type,
+                    )
+                )
+
+        while not queue.empty() and len(event_id_results) < limit:
+            try:
+                _, _, event_id, event_type = queue.get_nowait()
+            except Empty:
+                break
+
+            if event_id in event_id_results:
+                continue
+
+            event_id_results.add(event_id)
+
+            if self.hs.config.experimental.msc2716_enabled:
+                # Try to find any potential historical batches of message history.
+                #
+                # First we look for an insertion event connected to the current
+                # event (by prev_event). If we find any, we'll add them to the queue
+                # and navigate up the DAG like normal in the next iteration of the
+                # loop.
+                connected_insertion_event_backfill_results = (
+                    await self.store.get_connected_insertion_event_backfill_results(
+                        event_id, limit - len(event_id_results)
+                    )
+                )
+                logger.debug(
+                    "get_backfill_events: connected_insertion_event_backfill_results=%s",
+                    connected_insertion_event_backfill_results,
+                )
+                for (
+                    connected_insertion_event_backfill_item
+                ) in connected_insertion_event_backfill_results:
+                    if (
+                        connected_insertion_event_backfill_item.event_id
+                        not in event_id_results
+                    ):
+                        queue.put(
+                            (
+                                -connected_insertion_event_backfill_item.depth,
+                                -connected_insertion_event_backfill_item.stream_ordering,
+                                connected_insertion_event_backfill_item.event_id,
+                                connected_insertion_event_backfill_item.type,
+                            )
+                        )
+
+                # Second, we need to go and try to find any batch events connected
+                # to a given insertion event (by batch_id). If we find any, we'll
+                # add them to the queue and navigate up the DAG like normal in the
+                # next iteration of the loop.
+                if event_type == EventTypes.MSC2716_INSERTION:
+                    connected_batch_event_backfill_results = (
+                        await self.store.get_connected_batch_event_backfill_results(
+                            event_id, limit - len(event_id_results)
+                        )
+                    )
+                    logger.debug(
+                        "get_backfill_events: connected_batch_event_backfill_results=%s",
+                        connected_batch_event_backfill_results,
+                    )
+                    for (
+                        connected_batch_event_backfill_item
+                    ) in connected_batch_event_backfill_results:
+                        if (
+                            connected_batch_event_backfill_item.event_id
+                            not in event_id_results
+                        ):
+                            queue.put(
+                                (
+                                    -connected_batch_event_backfill_item.depth,
+                                    -connected_batch_event_backfill_item.stream_ordering,
+                                    connected_batch_event_backfill_item.event_id,
+                                    connected_batch_event_backfill_item.type,
+                                )
+                            )
+
+            # Now we just navigate up the DAG by prev_events as normal
+            connected_prev_event_backfill_results = (
+                await self.store.get_connected_prev_event_backfill_results(
+                    event_id, limit - len(event_id_results)
+                )
+            )
+            logger.debug(
+                "get_backfill_events: connected_prev_event_backfill_results=%s",
+                connected_prev_event_backfill_results,
+            )
+            for (
+                connected_prev_event_backfill_item
+            ) in connected_prev_event_backfill_results:
+                if connected_prev_event_backfill_item.event_id not in event_id_results:
+                    queue.put(
+                        (
+                            -connected_prev_event_backfill_item.depth,
+                            -connected_prev_event_backfill_item.stream_ordering,
+                            connected_prev_event_backfill_item.event_id,
+                            connected_prev_event_backfill_item.type,
+                        )
+                    )
+
+        events = await self.store.get_events_as_list(event_id_results)
+        return sorted(
+            events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+        )
+
     @log_function
     async def on_backfill_request(
         self, origin: str, room_id: str, pdu_list: List[str], limit: int
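To make the queue ordering concrete, here is a minimal, self-contained sketch (toy data, not Synapse code) of why pushing `(-depth, -stream_ordering)` tuples makes the `PriorityQueue` pop the newest-in-time events first, with `stream_ordering` breaking ties at equal depth:

```python
from queue import PriorityQueue

# (depth, stream_ordering, event_id) triples for a toy room DAG
events = [
    (1, 1, "$create"),
    (2, 2, "$member"),
    (3, 3, "$message-a"),
    (3, 4, "$message-b"),  # same depth as $message-a, later stream_ordering
]

queue = PriorityQueue()
for depth, stream_ordering, event_id in events:
    # Negate both values so the "largest" (newest) event sorts first
    queue.put((-depth, -stream_ordering, event_id))

while not queue.empty():
    _, _, event_id = queue.get_nowait()
    print(event_id)
# Prints $message-b, $message-a, $member, $create: newest-in-time first,
# with stream_ordering breaking the tie at depth 3.
```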
@@ -1053,6 +1183,34 @@ class FederationHandler:
         limit = min(limit, 100)

         events = await self.store.get_backfill_events(room_id, pdu_list, limit)
+        logger.info(
+            "old implementation backfill events=%s",
+            [
+                "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+                % (
+                    event.event_id,
+                    event.depth,
+                    event.content.get("body", event.type),
+                    event.prev_event_ids(),
+                )
+                for event in events
+            ],
+        )
+
+        events = await self.get_backfill_events(room_id, pdu_list, limit)
+        logger.info(
+            "new implementation backfill events=%s",
+            [
+                "event_id=%s,depth=%d,body=%s,prevs=%s\n"
+                % (
+                    event.event_id,
+                    event.depth,
+                    event.content.get("body", event.type),
+                    event.prev_event_ids(),
+                )
+                for event in events
+            ],
+        )

         events = await filter_events_for_server(self.storage, origin, events)

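The `on_backfill_request` hunk above runs the old and new implementations side by side and logs both result sets so they can be diffed before the old path is removed. A toy harness showing the same shadow-comparison pattern (the two `*_backfill` functions below are hypothetical stand-ins, not the Synapse API):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def old_backfill(limit: int) -> list:
    # Stand-in for the existing store.get_backfill_events path
    return ["$a", "$b", "$c"][:limit]

def new_backfill(limit: int) -> list:
    # Stand-in for the new PriorityQueue-based implementation
    return ["$a", "$c", "$b"][:limit]

old_events = old_backfill(3)
new_events = new_backfill(3)
logger.info("old implementation backfill events=%s", old_events)
logger.info("new implementation backfill events=%s", new_events)
if old_events != new_events:
    # Divergence is what this scaffolding is meant to surface in the logs
    logger.info("implementations diverge: %s vs %s", old_events, new_events)
```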
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index ab2ed53bce..c9060a594f 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -702,38 +702,38 @@ class FederationEventHandler:
                 event.event_id
             )

-            # Maybe we can get lucky and save ourselves a lookup
-            # by checking the events in the backfill first
-            insertion_event = event_map[
-                insertion_event_id
-            ] or await self._store.get_event(
-                insertion_event_id, allow_none=True
-            )
+            # # Maybe we can get lucky and save ourselves a lookup
+            # # by checking the events in the backfill first
+            # insertion_event = event_map[
+            #     insertion_event_id
+            # ] or await self._store.get_event(
+            #     insertion_event_id, allow_none=True
+            # )

-            if insertion_event:
-                # Connect the insertion events' `prev_event` successors
-                # via fake edges pointing to the insertion event itself
-                # so the insertion event sorts topologically
-                # behind-in-time the successor. Nestled perfectly
-                # between the prev_event and the successor.
-                for insertion_prev_event_id in insertion_event.prev_event_ids():
-                    successor_event_ids = successor_event_id_map[
-                        insertion_prev_event_id
-                    ]
-                    logger.info(
-                        "insertion_event_id=%s successor_event_ids=%s",
-                        insertion_event_id,
-                        successor_event_ids,
-                    )
-                    if successor_event_ids:
-                        for successor_event_id in successor_event_ids:
-                            # Don't add itself back as a successor
-                            if successor_event_id != insertion_event_id:
-                                # Fake edge to point the successor back
-                                # at the insertion event
-                                event_id_graph.setdefault(
-                                    successor_event_id, []
-                                ).append(insertion_event_id)
+            # if insertion_event:
+            #     # Connect the insertion events' `prev_event` successors
+            #     # via fake edges pointing to the insertion event itself
+            #     # so the insertion event sorts topologically
+            #     # behind-in-time the successor. Nestled perfectly
+            #     # between the prev_event and the successor.
+            #     for insertion_prev_event_id in insertion_event.prev_event_ids():
+            #         successor_event_ids = successor_event_id_map[
+            #             insertion_prev_event_id
+            #         ]
+            #         logger.info(
+            #             "insertion_event_id=%s successor_event_ids=%s",
+            #             insertion_event_id,
+            #             successor_event_ids,
+            #         )
+            #         if successor_event_ids:
+            #             for successor_event_id in successor_event_ids:
+            #                 # Don't add itself back as a successor
+            #                 if successor_event_id != insertion_event_id:
+            #                     # Fake edge to point the successor back
+            #                     # at the insertion event
+            #                     event_id_graph.setdefault(
+            #                         successor_event_id, []
+            #                     ).append(insertion_event_id)

         # TODO: We also need to add fake edges to connect the oldest-in-time messages
         # in the batch to the event we branched off of, see https://github.com/matrix-org/synapse/pull/11114#discussion_r739300985
@@ -773,17 +773,17 @@ class FederationEventHandler:

         # We want to sort these by depth so we process them and
         # tell clients about them in order.
-        # sorted_events = sorted(events, key=lambda x: x.depth)
+        sorted_events = sorted(events, key=lambda x: x.depth)

-        # We want to sort topologically so we process them and tell clients
-        # about them in order.
-        sorted_events = []
-        event_ids = [event.event_id for event in events]
-        event_map = {event.event_id: event for event in events}
-        event_id_graph = await self.generateEventIdGraphFromEvents(events)
-        for event_id in sorted_topologically(event_ids, event_id_graph):
-            sorted_events.append(event_map[event_id])
-        sorted_events = reversed(sorted_events)
+        # # We want to sort topologically so we process them and tell clients
+        # # about them in order.
+        # sorted_events = []
+        # event_ids = [event.event_id for event in events]
+        # event_map = {event.event_id: event for event in events}
+        # event_id_graph = await self.generateEventIdGraphFromEvents(events)
+        # for event_id in sorted_topologically(event_ids, event_id_graph):
+        #     sorted_events.append(event_map[event_id])
+        # sorted_events = reversed(sorted_events)

         logger.info(
             "backfill sorted_events=%s",
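The commented-out block above sorted events topologically, with fake edges forcing an insertion event to sort between its `prev_event` and that event's successor; the diff reverts to a plain depth sort. A generic sketch of the fake-edge idea using the standard library's `graphlib` (an assumed stand-in here, not Synapse's `sorted_topologically` helper):

```python
from graphlib import TopologicalSorter  # Python 3.9+

# Map each event to the set of events that must sort before it (its
# prev_events). The fake edge makes $successor depend on $insertion too,
# nestling the insertion event between $prev and $successor.
graph = {
    "$prev": set(),
    "$insertion": {"$prev"},
    "$successor": {"$prev", "$insertion"},  # fake edge: $insertion added
}

print(list(TopologicalSorter(graph).static_order()))
# ['$prev', '$insertion', '$successor']
```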
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4a4d35f77c..a569e8146a 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,7 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import Collection, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple

 from prometheus_client import Counter, Gauge

@@ -53,6 +53,14 @@ pdus_pruned_from_federation_queue = Counter(
 logger = logging.getLogger(__name__)


+# All the info we need to navigate the DAG while backfilling
+class BackfillQueueNavigationItem(NamedTuple):
+    depth: int
+    stream_ordering: int
+    event_id: str
+    type: str
+
+
 class _NoChainCoverIndex(Exception):
     def __init__(self, room_id: str):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
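A small sketch of how the navigation queries below are meant to feed this NamedTuple: each `(depth, stream_ordering, event_id, type)` row maps onto a `BackfillQueueNavigationItem`, giving the handler named fields instead of bare tuple indexing (the row data here is made up for illustration):

```python
from typing import List, NamedTuple

class BackfillQueueNavigationItem(NamedTuple):
    depth: int
    stream_ordering: int
    event_id: str
    type: str

# A fabricated row of the shape the SQL below returns
rows = [(3, 7, "$insertion", "org.matrix.msc2716.insertion")]
items: List[BackfillQueueNavigationItem] = [
    BackfillQueueNavigationItem(
        depth=row[0], stream_ordering=row[1], event_id=row[2], type=row[3]
    )
    for row in rows
]
print(items[0].event_id, items[0].depth)  # $insertion 3
```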
+ """ + + # Find any batch connections for the given insertion event + txn.execute( + batch_connection_query, + (insertion_event_id, limit), + ) + batch_start_event_id_results = txn.fetchall() + return [ + BackfillQueueNavigationItem( + depth=row[0], + stream_ordering=row[1], + event_id=row[2], + type=row[3], + ) + for row in batch_start_event_id_results + ] + + return await self.db_pool.runInteraction( + "get_connected_batch_event_backfill_results", + _get_connected_batch_event_backfill_results_txn, + ) + + async def get_connected_prev_event_backfill_results( + self, event_id: str, limit: int + ) -> List[BackfillQueueNavigationItem]: + def _get_connected_prev_event_backfill_results_txn(txn): + # Look for the prev_event_id connected to the given event_id + connected_prev_event_query = """ + SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges + /* Get the depth and stream_ordering of the prev_event_id from the events table */ + INNER JOIN events + ON prev_event_id = events.event_id + /* Look for an edge which matches the given event_id */ + WHERE event_edges.event_id = ? + AND event_edges.is_state = ? + /* Because we can have many events at the same depth, + * we want to also tie-break and sort on stream_ordering */ + ORDER BY depth DESC, stream_ordering DESC + LIMIT ? + """ + + txn.execute( + connected_prev_event_query, + (event_id, False, limit), + ) + prev_event_id_results = txn.fetchall() + return [ + BackfillQueueNavigationItem( + depth=row[0], + stream_ordering=row[1], + event_id=row[2], + type=row[3], + ) + for row in prev_event_id_results + ] + + return await self.db_pool.runInteraction( + "get_connected_prev_event_backfill_results", + _get_connected_prev_event_backfill_results_txn, + ) + async def get_backfill_events(self, room_id: str, event_list: list, limit: int): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit`