forked from MirrorHub/synapse
Handle cancellation in EventsWorkerStore._get_events_from_cache_or_db
(#12529)
Multiple calls to `EventsWorkerStore._get_events_from_cache_or_db` can reuse the same database fetch, which is initiated by the first call. Ensure that cancelling the first call doesn't cancel the other calls sharing the same database fetch. Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
parent
813d728d09
commit
8a87b4435a
3 changed files with 167 additions and 34 deletions
1
changelog.d/12529.misc
Normal file
1
changelog.d/12529.misc
Normal file
|
@@ -0,0 +1 @@
|
||||||
|
Handle cancellation in `EventsWorkerStore._get_events_from_cache_or_db`.
|
|
@@ -75,7 +75,7 @@ from synapse.storage.util.id_generators import (
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
@@ -640,42 +640,57 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
missing_events_ids.difference_update(already_fetching_ids)
|
missing_events_ids.difference_update(already_fetching_ids)
|
||||||
|
|
||||||
if missing_events_ids:
|
if missing_events_ids:
|
||||||
log_ctx = current_context()
|
|
||||||
log_ctx.record_event_fetch(len(missing_events_ids))
|
|
||||||
|
|
||||||
# Add entries to `self._current_event_fetches` for each event we're
|
async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
|
||||||
# going to pull from the DB. We use a single deferred that resolves
|
"""Fetches the events in `missing_event_ids` from the database.
|
||||||
# to all the events we pulled from the DB (this will result in this
|
|
||||||
# function returning more events than requested, but that can happen
|
|
||||||
# already due to `_get_events_from_db`).
|
|
||||||
fetching_deferred: ObservableDeferred[
|
|
||||||
Dict[str, EventCacheEntry]
|
|
||||||
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
|
||||||
for event_id in missing_events_ids:
|
|
||||||
self._current_event_fetches[event_id] = fetching_deferred
|
|
||||||
|
|
||||||
# Note that _get_events_from_db is also responsible for turning db rows
|
Also creates entries in `self._current_event_fetches` to allow
|
||||||
# into FrozenEvents (via _get_event_from_row), which involves seeing if
|
concurrent `_get_events_from_cache_or_db` calls to reuse the same fetch.
|
||||||
# the events have been redacted, and if so pulling the redaction event out
|
"""
|
||||||
# of the database to check it.
|
log_ctx = current_context()
|
||||||
#
|
log_ctx.record_event_fetch(len(missing_events_ids))
|
||||||
try:
|
|
||||||
missing_events = await self._get_events_from_db(
|
|
||||||
missing_events_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
event_entry_map.update(missing_events)
|
# Add entries to `self._current_event_fetches` for each event we're
|
||||||
except Exception as e:
|
# going to pull from the DB. We use a single deferred that resolves
|
||||||
with PreserveLoggingContext():
|
# to all the events we pulled from the DB (this will result in this
|
||||||
fetching_deferred.errback(e)
|
# function returning more events than requested, but that can happen
|
||||||
raise e
|
# already due to `_get_events_from_db`).
|
||||||
finally:
|
fetching_deferred: ObservableDeferred[
|
||||||
# Ensure that we mark these events as no longer being fetched.
|
Dict[str, EventCacheEntry]
|
||||||
|
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||||
for event_id in missing_events_ids:
|
for event_id in missing_events_ids:
|
||||||
self._current_event_fetches.pop(event_id, None)
|
self._current_event_fetches[event_id] = fetching_deferred
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
# Note that _get_events_from_db is also responsible for turning db rows
|
||||||
fetching_deferred.callback(missing_events)
|
# into FrozenEvents (via _get_event_from_row), which involves seeing if
|
||||||
|
# the events have been redacted, and if so pulling the redaction event
|
||||||
|
# out of the database to check it.
|
||||||
|
#
|
||||||
|
try:
|
||||||
|
missing_events = await self._get_events_from_db(
|
||||||
|
missing_events_ids,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
fetching_deferred.errback(e)
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
# Ensure that we mark these events as no longer being fetched.
|
||||||
|
for event_id in missing_events_ids:
|
||||||
|
self._current_event_fetches.pop(event_id, None)
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
fetching_deferred.callback(missing_events)
|
||||||
|
|
||||||
|
return missing_events
|
||||||
|
|
||||||
|
# We must allow the database fetch to complete in the presence of
|
||||||
|
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
|
||||||
|
# reuse the same fetch.
|
||||||
|
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
|
||||||
|
get_missing_events_from_db()
|
||||||
|
)
|
||||||
|
event_entry_map.update(missing_events)
|
||||||
|
|
||||||
if already_fetching_deferreds:
|
if already_fetching_deferreds:
|
||||||
# Wait for the other event requests to finish and add their results
|
# Wait for the other event requests to finish and add their results
|
||||||
|
|
|
@@ -13,10 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Generator
|
from typing import Generator, Tuple
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from twisted.enterprise.adbapi import ConnectionPool
|
from twisted.enterprise.adbapi import ConnectionPool
|
||||||
from twisted.internet.defer import ensureDeferred
|
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.room_versions import EventFormatVersions, RoomVersions
|
from synapse.api.room_versions import EventFormatVersions, RoomVersions
|
||||||
|
@@ -281,3 +282,119 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# This next event fetch should succeed
|
# This next event fetch should succeed
|
||||||
self.get_success(self.store.get_event(self.event_ids[0]))
|
self.get_success(self.store.get_event(self.event_ids[0]))
|
||||||
|
|
||||||
|
|
||||||
|
class GetEventCancellationTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""Test cancellation of `get_event` calls."""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
||||||
|
self.store: EventsWorkerStore = hs.get_datastores().main
|
||||||
|
|
||||||
|
self.user = self.register_user("user", "pass")
|
||||||
|
self.token = self.login(self.user, "pass")
|
||||||
|
|
||||||
|
self.room = self.helper.create_room_as(self.user, tok=self.token)
|
||||||
|
|
||||||
|
res = self.helper.send(self.room, tok=self.token)
|
||||||
|
self.event_id = res["event_id"]
|
||||||
|
|
||||||
|
# Reset the event cache so the tests start with it empty
|
||||||
|
self.store._get_event_cache.clear()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def blocking_get_event_calls(
|
||||||
|
self,
|
||||||
|
) -> Generator[
|
||||||
|
Tuple["Deferred[None]", "Deferred[None]", "Deferred[None]"], None, None
|
||||||
|
]:
|
||||||
|
"""Starts two concurrent `get_event` calls for the same event.
|
||||||
|
|
||||||
|
Both `get_event` calls will use the same database fetch, which will be blocked
|
||||||
|
at the time this function returns.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing:
|
||||||
|
* A `Deferred` that unblocks the database fetch.
|
||||||
|
* A cancellable `Deferred` for the first `get_event` call.
|
||||||
|
* A cancellable `Deferred` for the second `get_event` call.
|
||||||
|
"""
|
||||||
|
# Patch `DatabasePool.runWithConnection` to block.
|
||||||
|
unblock: "Deferred[None]" = Deferred()
|
||||||
|
original_runWithConnection = self.store.db_pool.runWithConnection
|
||||||
|
|
||||||
|
async def runWithConnection(*args, **kwargs):
|
||||||
|
await unblock
|
||||||
|
return await original_runWithConnection(*args, **kwargs)
|
||||||
|
|
||||||
|
with mock.patch.object(
|
||||||
|
self.store.db_pool,
|
||||||
|
"runWithConnection",
|
||||||
|
new=runWithConnection,
|
||||||
|
):
|
||||||
|
ctx1 = LoggingContext("get_event1")
|
||||||
|
ctx2 = LoggingContext("get_event2")
|
||||||
|
|
||||||
|
async def get_event(ctx: LoggingContext) -> None:
|
||||||
|
with ctx:
|
||||||
|
await self.store.get_event(self.event_id)
|
||||||
|
|
||||||
|
get_event1 = ensureDeferred(get_event(ctx1))
|
||||||
|
get_event2 = ensureDeferred(get_event(ctx2))
|
||||||
|
|
||||||
|
# Both `get_event` calls ought to be blocked.
|
||||||
|
self.assertNoResult(get_event1)
|
||||||
|
self.assertNoResult(get_event2)
|
||||||
|
|
||||||
|
yield unblock, get_event1, get_event2
|
||||||
|
|
||||||
|
# Confirm that the two `get_event` calls shared the same database fetch.
|
||||||
|
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
|
||||||
|
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
|
||||||
|
|
||||||
|
def test_first_get_event_cancelled(self):
|
||||||
|
"""Test cancellation of the first `get_event` call sharing a database fetch.
|
||||||
|
|
||||||
|
The first `get_event` call is the one which initiates the fetch. We expect the
|
||||||
|
fetch to complete despite the cancellation. Furthermore, the first `get_event`
|
||||||
|
call must not abort before the fetch is complete, otherwise the fetch will be
|
||||||
|
using a finished logging context.
|
||||||
|
"""
|
||||||
|
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
|
||||||
|
# Cancel the first `get_event` call.
|
||||||
|
get_event1.cancel()
|
||||||
|
# The first `get_event` call must not abort immediately, otherwise its
|
||||||
|
# logging context will be finished while it is still in use by the database
|
||||||
|
# fetch.
|
||||||
|
self.assertNoResult(get_event1)
|
||||||
|
# The second `get_event` call must not be cancelled.
|
||||||
|
self.assertNoResult(get_event2)
|
||||||
|
|
||||||
|
# Unblock the database fetch.
|
||||||
|
unblock.callback(None)
|
||||||
|
# A `CancelledError` should be raised out of the first `get_event` call.
|
||||||
|
exc = self.get_failure(get_event1, CancelledError).value
|
||||||
|
self.assertIsInstance(exc, CancelledError)
|
||||||
|
# The second `get_event` call should complete successfully.
|
||||||
|
self.get_success(get_event2)
|
||||||
|
|
||||||
|
def test_second_get_event_cancelled(self):
|
||||||
|
"""Test cancellation of the second `get_event` call sharing a database fetch."""
|
||||||
|
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
|
||||||
|
# Cancel the second `get_event` call.
|
||||||
|
get_event2.cancel()
|
||||||
|
# The first `get_event` call must not be cancelled.
|
||||||
|
self.assertNoResult(get_event1)
|
||||||
|
# The second `get_event` call gets cancelled immediately.
|
||||||
|
exc = self.get_failure(get_event2, CancelledError).value
|
||||||
|
self.assertIsInstance(exc, CancelledError)
|
||||||
|
|
||||||
|
# Unblock the database fetch.
|
||||||
|
unblock.callback(None)
|
||||||
|
# The first `get_event` call should complete successfully.
|
||||||
|
self.get_success(get_event1)
|
||||||
|
|
Loading…
Reference in a new issue