From e6a6c4fbab8f409a20422659dddc3b7437a2cc07 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 24 Jul 2019 16:37:50 +0100 Subject: [PATCH] split _get_events_from_db out of _enqueue_events --- synapse/storage/events_worker.py | 83 ++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 5b606bec7..6e5f1cf6e 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -343,13 +343,12 @@ class EventsWorkerStore(SQLBaseStore): log_ctx = LoggingContext.current_context() log_ctx.record_event_fetch(len(missing_events_ids)) - # Note that _enqueue_events is also responsible for turning db rows + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - # _enqueue_events is a bit of a rubbish name but naming is hard. - missing_events = yield self._enqueue_events( + missing_events = yield self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -458,43 +457,25 @@ class EventsWorkerStore(SQLBaseStore): self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks - def _enqueue_events(self, events, allow_rejected=False): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. + def _get_events_from_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the database. + + Returned events will be added to the cache for future lookups. Args: - events (Iterable[str]): events to be fetched. + event_ids (Iterable[str]): The event_ids of the events to fetch + allow_rejected (bool): Whether to include rejected events Returns: - Deferred[Dict[str, _EventCacheEntry]]: map from event id to result. + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result. """ - if not events: + if not event_ids: return {} - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append((events, events_d)) + row_map = yield self._enqueue_events(event_ids) - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - run_as_background_process( - "fetch_events", self.runWithConnection, self._do_fetch - ) - - logger.debug("Loading %d events", len(events)) - with PreserveLoggingContext(): - row_map = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) - - rows = (row_map.get(event_id) for event_id in events) + rows = (row_map.get(event_id) for event_id in event_ids) # filter out absent rows rows = filter(operator.truth, rows) @@ -521,6 +502,44 @@ class EventsWorkerStore(SQLBaseStore): return {e.event.event_id: e for e in res if e} + @defer.inlineCallbacks + def _enqueue_events(self, events): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + + Args: + events (Iterable[str]): events to be fetched. + + Returns: + Deferred[Dict[str, Dict]]: map from event id to row data from the database. + May contain events that weren't requested. + """ + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append((events, events_d)) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.runWithConnection, self._do_fetch + ) + + logger.debug("Loading %d events: %s", len(events), events) + with PreserveLoggingContext(): + row_map = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) + + return row_map + def _fetch_event_rows(self, txn, event_ids): """Fetch event rows from the database