Improve event caching code (#10119)

Ensure we only load an event from the DB once when the same event is requested multiple times at once.
Erik Johnston, 2021-08-04 13:54:51 +01:00, committed by GitHub
parent 11540be55e · commit c37dad67ab
4 changed files with 158 additions and 43 deletions
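
The heart of the change is request coalescing in `_get_events_from_cache_or_db`: the first request to miss the cache registers an in-flight marker for each event ID it is about to fetch, and any concurrent request that misses on the same ID waits on that marker instead of issuing a second DB query. A minimal standalone sketch of the pattern using Synapse's `ObservableDeferred` (the `Coalescer` class and the `fetch` callback are illustrative names, not part of this commit):

    from typing import Awaitable, Callable, Dict, TypeVar

    from twisted.internet import defer

    from synapse.util.async_helpers import ObservableDeferred

    T = TypeVar("T")


    class Coalescer:
        """Collapse concurrent lookups for the same key into a single fetch."""

        def __init__(self) -> None:
            # Keys with a fetch currently in flight.
            self._in_flight: Dict[str, ObservableDeferred] = {}

        async def get(self, key: str, fetch: Callable[[str], Awaitable[T]]) -> T:
            existing = self._in_flight.get(key)
            if existing is not None:
                # A fetch for this key is already running: share its result.
                return await existing.observe()

            observable: ObservableDeferred = ObservableDeferred(defer.Deferred())
            self._in_flight[key] = observable
            try:
                result = await fetch(key)
            except Exception as e:
                observable.errback(e)  # wake waiters with the failure
                raise
            finally:
                # Stop advertising the fetch whether it succeeded or failed.
                self._in_flight.pop(key, None)
            observable.callback(result)  # wake waiters with the result
            return result

The production version below is more involved: it coalesces whole batches of event IDs per call, and it resolves the shared deferred inside `PreserveLoggingContext` so that waiting requests resume in their own logging contexts.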

changelog.d/10119.misc (new file, 1 addition)

@@ -0,0 +1 @@
+Improve event caching mechanism to avoid having multiple copies of an event in memory at a time.

synapse/storage/databases/main/events_worker.py

@@ -14,7 +14,6 @@
 import logging
 import threading
-from collections import namedtuple
 from typing import (
     Collection,
     Container,
@@ -27,6 +26,7 @@ from typing import (
     overload,
 )
 
+import attr
 from constantly import NamedConstant, Names
 from typing_extensions import Literal
@@ -42,7 +42,11 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
-from synapse.logging.context import PreserveLoggingContext, current_context
+from synapse.logging.context import (
+    PreserveLoggingContext,
+    current_context,
+    make_deferred_yieldable,
+)
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
@@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+@attr.s(slots=True, auto_attribs=True)
+class _EventCacheEntry:
+    event: EventBase
+    redacted_event: Optional[EventBase]
 
 
 class EventRedactBehaviour(Names):
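
Swapping the namedtuple for a slots-based attrs class keeps per-entry memory overhead low while gaining typed, named fields. One behavioral difference worth noting, sketched here with a hypothetical `some_event` standing in for an `EventBase`: attrs instances are not tuples.

    entry = _EventCacheEntry(event=some_event, redacted_event=None)
    assert entry.redacted_event is None  # attribute access works as before

    # ...but tuple-style access no longer does:
    # event, redacted = entry   # TypeError: cannot unpack non-iterable
    # entry[0]                  # TypeError: not subscriptable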
@@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore):
             max_size=hs.config.caches.event_cache_size,
         )
 
+        # Map from event ID to a deferred that will result in a map from event
+        # ID to cache entry. Note that the returned dict may not have the
+        # requested event in it if the event isn't in the DB.
+        self._current_event_fetches: Dict[
+            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+        ] = {}
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
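
For readers unfamiliar with `ObservableDeferred`: it wraps a single Deferred, and each call to `observe()` hands back an independent Deferred that fires with the same result, which is what lets any number of requests wait on one in-flight DB fetch. A quick sketch of those semantics (standalone, not from the diff):

    from twisted.internet import defer

    from synapse.util.async_helpers import ObservableDeferred

    observable = ObservableDeferred(defer.Deferred())

    first = observable.observe()   # each observer gets its own Deferred
    second = observable.observe()

    # Resolving the wrapped deferred fires every observer with the result.
    observable.callback({"$ev:example.org": "cache entry"})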
@@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
-    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(
+        self, event_ids: Iterable[str], allow_rejected: bool = False
+    ) -> Dict[str, _EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
-            event_ids (Iterable[str]): The event_ids of the events to fetch
-
-            allow_rejected (bool): Whether to include rejected events. If False,
+            event_ids: The event_ids of the events to fetch
+
+            allow_rejected: Whether to include rejected events. If False,
                 rejected events are omitted from the response.
 
         Returns:
-            Dict[str, _EventCacheEntry]:
-                map from event id to result
+            map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
-            event_ids, allow_rejected=allow_rejected
+            event_ids,
         )
 
-        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+        missing_events_ids = {e for e in event_ids if e not in event_entry_map}
+
+        # We now look up if we're already fetching some of the events in the DB,
+        # if so we wait for those lookups to finish instead of pulling the same
+        # events out of the DB multiple times.
+        already_fetching: Dict[str, defer.Deferred] = {}
+        for event_id in missing_events_ids:
+            deferred = self._current_event_fetches.get(event_id)
+            if deferred is not None:
+                # We're already pulling the event out of the DB. Add the deferred
+                # to the collection of deferreds to wait on.
+                already_fetching[event_id] = deferred.observe()
+
+        missing_events_ids.difference_update(already_fetching)
 
         if missing_events_ids:
             log_ctx = current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
+            # Add entries to `self._current_event_fetches` for each event we're
+            # going to pull from the DB. We use a single deferred that resolves
+            # to all the events we pulled from the DB (this will result in this
+            # function returning more events than requested, but that can happen
+            # already due to `_get_events_from_db`).
+            fetching_deferred: ObservableDeferred[
+                Dict[str, _EventCacheEntry]
+            ] = ObservableDeferred(defer.Deferred())
+            for event_id in missing_events_ids:
+                self._current_event_fetches[event_id] = fetching_deferred
+
             # Note that _get_events_from_db is also responsible for turning db rows
             # into FrozenEvents (via _get_event_from_row), which involves seeing if
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            missing_events = await self._get_events_from_db(
-                missing_events_ids, allow_rejected=allow_rejected
-            )
-
-            event_entry_map.update(missing_events)
+            try:
+                missing_events = await self._get_events_from_db(
+                    missing_events_ids,
+                )
+
+                event_entry_map.update(missing_events)
+            except Exception as e:
+                with PreserveLoggingContext():
+                    fetching_deferred.errback(e)
+                raise e
+            finally:
+                # Ensure that we mark these events as no longer being fetched.
+                for event_id in missing_events_ids:
+                    self._current_event_fetches.pop(event_id, None)
+
+            with PreserveLoggingContext():
+                fetching_deferred.callback(missing_events)
+
+        if already_fetching:
+            # Wait for the other event requests to finish and add their results
+            # to ours.
+            results = await make_deferred_yieldable(
+                defer.gatherResults(
+                    already_fetching.values(),
+                    consumeErrors=True,
+                )
+            ).addErrback(unwrapFirstError)
+
+            for result in results:
+                event_entry_map.update(result)
+
+        if not allow_rejected:
+            event_entry_map = {
+                event_id: entry
+                for event_id, entry in event_entry_map.items()
+                if not entry.event.rejected_reason
+            }
 
         return event_entry_map
 
     def _invalidate_get_event_cache(self, event_id):
         self._get_event_cache.invalidate((event_id,))
 
-    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
-        """Fetch events from the caches
+    def _get_events_from_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, _EventCacheEntry]:
+        """Fetch events from the caches.
+
+        May return rejected events.
 
         Args:
-            events (Iterable[str]): list of event_ids to fetch
-            allow_rejected (bool): Whether to return events that were rejected
-            update_metrics (bool): Whether to update the cache hit ratio metrics
-
-        Returns:
-            dict of event_id -> _EventCacheEntry for each event_id in cache. If
-            allow_rejected is `False` then there will still be an entry but it
-            will be `None`
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
         """
         event_map = {}
 
@@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore):
             if not ret:
                 continue
 
-            if allow_rejected or not ret.event.rejected_reason:
-                event_map[event_id] = ret
-            else:
-                event_map[event_id] = None
+            event_map[event_id] = ret
 
         return event_map
@@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore):
 
             with PreserveLoggingContext():
                 self.hs.get_reactor().callFromThread(fire, event_list, e)
 
-    async def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, _EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
+        May return rejected events.
+
         Returned events will be added to the cache for future lookups.
 
         Unknown events are omitted from the response.
 
         Args:
-            event_ids (Iterable[str]): The event_ids of the events to fetch
-
-            allow_rejected (bool): Whether to include rejected events. If False,
-                rejected events are omitted from the response.
+            event_ids: The event_ids of the events to fetch
 
         Returns:
-            Dict[str, _EventCacheEntry]:
-                map from event id to result. May return extra events which
-                weren't asked for.
+            map from event id to result. May return extra events which
+            weren't asked for.
         """
         fetched_events = {}
         events_to_fetch = event_ids
@@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore):
 
             rejected_reason = row["rejected_reason"]
 
-            if not allow_rejected and rejected_reason:
-                continue
-
             # If the event or metadata cannot be parsed, log the error and act
             # as if the event is unknown.
             try:
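
Net effect of the `allow_rejected` changes above: `_get_events_from_db` and `_get_events_from_cache` now always return rejected events, so every caller shares identical cache and in-flight entries, and the filtering happens once in `_get_events_from_cache_or_db`. A caller that reads the cache directly must now check for itself, roughly:

    entry = event_map.get(event_id)
    if entry is not None and not entry.event.rejected_reason:
        ...  # safe to treat the event as accepted

which is exactly the adjustment made to roommember.py below.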

synapse/storage/databases/main/roommember.py

@@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't update the event cache hit ratio as it completely throws off
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
-        event_map = self._get_events_from_cache(
-            member_event_ids, allow_rejected=False, update_metrics=False
-        )
+        event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
 
         missing_member_event_ids = []
         for event_id in member_event_ids:
             ev_entry = event_map.get(event_id)
-            if ev_entry:
+            if ev_entry and not ev_entry.event.rejected_reason:
                 if ev_entry.event.membership == Membership.JOIN:
                     users_in_room[ev_entry.event.state_key] = ProfileInfo(
                         display_name=ev_entry.event.content.get("displayname", None),

tests/storage/databases/main/test_events_worker.py

@@ -14,7 +14,10 @@
 import json
 
 from synapse.logging.context import LoggingContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.async_helpers import yieldable_gather_results
 
 from tests import unittest
@@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
         self.assertEquals(res, {"event10"})
         self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+
+
+class EventCacheTestCase(unittest.HomeserverTestCase):
+    """Test that the various layers of event cache works."""
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store: EventsWorkerStore = hs.get_datastore()
+
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+
+        self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+        res = self.helper.send(self.room, tok=self.token)
+        self.event_id = res["event_id"]
+
+        # Reset the event cache so the tests start with it empty
+        self.store._get_event_cache.clear()
+
+    def test_simple(self):
+        """Test that we cache events that we pull from the DB."""
+
+        with LoggingContext("test") as ctx:
+            self.get_success(self.store.get_event(self.event_id))
+
+            # We should have fetched the event from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+    def test_dedupe(self):
+        """Test that if we request the same event multiple times we only pull it
+        out once.
+        """
+
+        with LoggingContext("test") as ctx:
+            d = yieldable_gather_results(
+                self.store.get_event, [self.event_id, self.event_id]
+            )
+            self.get_success(d)
+
+            # We should have fetched the event from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)