forked from MirrorHub/synapse
Handle cancellation in EventsWorkerStore._get_events_from_cache_or_db
(#12529)
Multiple calls to `EventsWorkerStore._get_events_from_cache_or_db` can reuse the same database fetch, which is initiated by the first call. Ensure that cancelling the first call doesn't cancel the other calls sharing the same database fetch. Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
parent
813d728d09
commit
8a87b4435a
3 changed files with 167 additions and 34 deletions
1
changelog.d/12529.misc
Normal file
1
changelog.d/12529.misc
Normal file
|
@@ -0,0 +1 @@
|
|||
Handle cancellation in `EventsWorkerStore._get_events_from_cache_or_db`.
|
|
@@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
|
|||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
missing_events_ids.difference_update(already_fetching_ids)
|
||||
|
||||
if missing_events_ids:
|
||||
log_ctx = current_context()
|
||||
log_ctx.record_event_fetch(len(missing_events_ids))
|
||||
|
||||
# Add entries to `self._current_event_fetches` for each event we're
|
||||
# going to pull from the DB. We use a single deferred that resolves
|
||||
# to all the events we pulled from the DB (this will result in this
|
||||
# function returning more events than requested, but that can happen
|
||||
# already due to `_get_events_from_db`).
|
||||
fetching_deferred: ObservableDeferred[
|
||||
Dict[str, EventCacheEntry]
|
||||
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||
for event_id in missing_events_ids:
|
||||
self._current_event_fetches[event_id] = fetching_deferred
|
||||
async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
|
||||
"""Fetches the events in `missing_event_ids` from the database.
|
||||
|
||||
# Note that _get_events_from_db is also responsible for turning db rows
|
||||
# into FrozenEvents (via _get_event_from_row), which involves seeing if
|
||||
# the events have been redacted, and if so pulling the redaction event out
|
||||
# of the database to check it.
|
||||
#
|
||||
try:
|
||||
missing_events = await self._get_events_from_db(
|
||||
missing_events_ids,
|
||||
)
|
||||
Also creates entries in `self._current_event_fetches` to allow
|
||||
concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
|
||||
"""
|
||||
log_ctx = current_context()
|
||||
log_ctx.record_event_fetch(len(missing_events_ids))
|
||||
|
||||
event_entry_map.update(missing_events)
|
||||
except Exception as e:
|
||||
with PreserveLoggingContext():
|
||||
fetching_deferred.errback(e)
|
||||
raise e
|
||||
finally:
|
||||
# Ensure that we mark these events as no longer being fetched.
|
||||
# Add entries to `self._current_event_fetches` for each event we're
|
||||
# going to pull from the DB. We use a single deferred that resolves
|
||||
# to all the events we pulled from the DB (this will result in this
|
||||
# function returning more events than requested, but that can happen
|
||||
# already due to `_get_events_from_db`).
|
||||
fetching_deferred: ObservableDeferred[
|
||||
Dict[str, EventCacheEntry]
|
||||
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||
for event_id in missing_events_ids:
|
||||
self._current_event_fetches.pop(event_id, None)
|
||||
self._current_event_fetches[event_id] = fetching_deferred
|
||||
|
||||
with PreserveLoggingContext():
|
||||
fetching_deferred.callback(missing_events)
|
||||
# Note that _get_events_from_db is also responsible for turning db rows
|
||||
# into FrozenEvents (via _get_event_from_row), which involves seeing if
|
||||
# the events have been redacted, and if so pulling the redaction event
|
||||
# out of the database to check it.
|
||||
#
|
||||
try:
|
||||
missing_events = await self._get_events_from_db(
|
||||
missing_events_ids,
|
||||
)
|
||||
except Exception as e:
|
||||
with PreserveLoggingContext():
|
||||
fetching_deferred.errback(e)
|
||||
raise e
|
||||
finally:
|
||||
# Ensure that we mark these events as no longer being fetched.
|
||||
for event_id in missing_events_ids:
|
||||
self._current_event_fetches.pop(event_id, None)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
fetching_deferred.callback(missing_events)
|
||||
|
||||
return missing_events
|
||||
|
||||
# We must allow the database fetch to complete in the presence of
|
||||
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
|
||||
# reuse the same fetch.
|
||||
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
|
||||
get_missing_events_from_db()
|
||||
)
|
||||
event_entry_map.update(missing_events)
|
||||
|
||||
if already_fetching_deferreds:
|
||||
# Wait for the other event requests to finish and add their results
|
||||
|
|
|
@@ -13,10 +13,11 @@
|
|||
# limitations under the License.
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from typing import Generator, Tuple
|
||||
from unittest import mock
|
||||
|
||||
from twisted.enterprise.adbapi import ConnectionPool
|
||||
from twisted.internet.defer import ensureDeferred
|
||||
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.room_versions import EventFormatVersions, RoomVersions
|
||||
|
@@ -281,3 +282,119 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# This next event fetch should succeed
|
||||
self.get_success(self.store.get_event(self.event_ids[0]))
|
||||
|
||||
|
||||
class GetEventCancellationTestCase(unittest.HomeserverTestCase):
    """Test cancellation of `get_event` calls."""

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        # The store under test; `get_event` lives on EventsWorkerStore.
        self.store: EventsWorkerStore = hs.get_datastores().main

        self.user = self.register_user("user", "pass")
        self.token = self.login(self.user, "pass")

        self.room = self.helper.create_room_as(self.user, tok=self.token)

        # Create one event whose fetch the tests will block and cancel.
        res = self.helper.send(self.room, tok=self.token)
        self.event_id = res["event_id"]

        # Reset the event cache so the tests start with it empty
        self.store._get_event_cache.clear()

    @contextmanager
    def blocking_get_event_calls(
        self,
    ) -> Generator[
        Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None
    ]:
        """Starts two concurrent `get_event` calls for the same event.

        Both `get_event` calls will use the same database fetch, which will be blocked
        at the time this function returns.

        Returns:
            A tuple containing:
             * A `Deferred` that unblocks the database fetch.
             * A cancellable `Deferred` for the first `get_event` call.
             * A cancellable `Deferred` for the second `get_event` call.
        """
        # Patch `DatabasePool.runWithConnection` to block.
        unblock: "Deferred[None]" = Deferred()
        original_runWithConnection = self.store.db_pool.runWithConnection

        # Wrapper that waits on `unblock` before delegating to the real
        # implementation, so the database fetch stalls until the test says so.
        async def runWithConnection(*args, **kwargs):
            await unblock
            return await original_runWithConnection(*args, **kwargs)

        with mock.patch.object(
            self.store.db_pool,
            "runWithConnection",
            new=runWithConnection,
        ):
            # Separate logging contexts so each call's resource usage (in
            # particular `evt_db_fetch_count`) can be checked independently.
            ctx1 = LoggingContext("get_event1")
            ctx2 = LoggingContext("get_event2")

            async def get_event(ctx: LoggingContext) -> None:
                with ctx:
                    await self.store.get_event(self.event_id)

            get_event1 = ensureDeferred(get_event(ctx1))
            get_event2 = ensureDeferred(get_event(ctx2))

            # Both `get_event` calls ought to be blocked.
            self.assertNoResult(get_event1)
            self.assertNoResult(get_event2)

            yield unblock, get_event1, get_event2

        # Confirm that the two `get_event` calls shared the same database fetch.
        self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
        self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)

    def test_first_get_event_cancelled(self) -> None:
        """Test cancellation of the first `get_event` call sharing a database fetch.

        The first `get_event` call is the one which initiates the fetch. We expect the
        fetch to complete despite the cancellation. Furthermore, the first `get_event`
        call must not abort before the fetch is complete, otherwise the fetch will be
        using a finished logging context.
        """
        with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
            # Cancel the first `get_event` call.
            get_event1.cancel()
            # The first `get_event` call must not abort immediately, otherwise its
            # logging context will be finished while it is still in use by the database
            # fetch.
            self.assertNoResult(get_event1)
            # The second `get_event` call must not be cancelled.
            self.assertNoResult(get_event2)

            # Unblock the database fetch.
            unblock.callback(None)
            # A `CancelledError` should be raised out of the first `get_event` call.
            exc = self.get_failure(get_event1, CancelledError).value
            self.assertIsInstance(exc, CancelledError)
            # The second `get_event` call should complete successfully.
            self.get_success(get_event2)

    def test_second_get_event_cancelled(self) -> None:
        """Test cancellation of the second `get_event` call sharing a database fetch."""
        with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
            # Cancel the second `get_event` call.
            get_event2.cancel()
            # The first `get_event` call must not be cancelled.
            self.assertNoResult(get_event1)
            # The second `get_event` call gets cancelled immediately.
            exc = self.get_failure(get_event2, CancelledError).value
            self.assertIsInstance(exc, CancelledError)

            # Unblock the database fetch.
            unblock.callback(None)
            # The first `get_event` call should complete successfully.
            self.get_success(get_event1)
|
||||
|
|
Loading…
Reference in a new issue