mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-12 04:52:26 +01:00
Safe async event cache (#13308)
Fix race conditions in the async cache invalidation logic, by separating the async & local invalidation calls and ensuring any async call i executed first. Signed off by Nick @ Beeper (@Fizzadar).
This commit is contained in:
parent
7864f33e28
commit
2ee0b6ef4b
8 changed files with 102 additions and 21 deletions
1
changelog.d/13308.misc
Normal file
1
changelog.d/13308.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar).
|
|
@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
cache doesn't exist. Mainly used for invalidating caches on workers,
|
cache doesn't exist. Mainly used for invalidating caches on workers,
|
||||||
where they may not have the cache.
|
where they may not have the cache.
|
||||||
|
|
||||||
|
Note that this function does not invalidate any remote caches, only the
|
||||||
|
local in-memory ones. Any remote invalidation must be performed before
|
||||||
|
calling this.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_name
|
cache_name
|
||||||
key: Entry to invalidate. If None then invalidates the entire
|
key: Entry to invalidate. If None then invalidates the entire
|
||||||
|
@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
if key is None:
|
if key is None:
|
||||||
cache.invalidate_all()
|
cache.invalidate_all()
|
||||||
else:
|
else:
|
||||||
cache.invalidate(tuple(key))
|
# Prefer any local-only invalidation method. Invalidating any non-local
|
||||||
|
# cache must be be done before this.
|
||||||
|
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
|
||||||
|
invalidate_method(tuple(key))
|
||||||
|
|
||||||
|
|
||||||
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
||||||
|
|
|
@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
|
@ -57,7 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.background_updates import BackgroundUpdater
|
from synapse.storage.background_updates import BackgroundUpdater
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||||
from synapse.storage.types import Connection, Cursor
|
from synapse.storage.types import Connection, Cursor
|
||||||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
from synapse.util.async_helpers import delay_cancellation
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
|
||||||
*,
|
*,
|
||||||
txn_name: Optional[str] = None,
|
txn_name: Optional[str] = None,
|
||||||
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||||
|
async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
|
||||||
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||||
) -> "LoggingTransaction":
|
) -> "LoggingTransaction":
|
||||||
if not txn_name:
|
if not txn_name:
|
||||||
|
@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
|
||||||
name=txn_name,
|
name=txn_name,
|
||||||
database_engine=self.engine,
|
database_engine=self.engine,
|
||||||
after_callbacks=after_callbacks,
|
after_callbacks=after_callbacks,
|
||||||
|
async_after_callbacks=async_after_callbacks,
|
||||||
exception_callbacks=exception_callbacks,
|
exception_callbacks=exception_callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
|
||||||
|
|
||||||
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
||||||
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
|
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
|
||||||
|
_AsyncCallbackListEntry = Tuple[
|
||||||
|
Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
|
||||||
|
]
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
@ -227,6 +233,10 @@ class LoggingTransaction:
|
||||||
that have been added by `call_after` which should be run on
|
that have been added by `call_after` which should be run on
|
||||||
successful completion of the transaction. None indicates that no
|
successful completion of the transaction. None indicates that no
|
||||||
callbacks should be allowed to be scheduled to run.
|
callbacks should be allowed to be scheduled to run.
|
||||||
|
async_after_callbacks: A list that asynchronous callbacks will be appended
|
||||||
|
to by `async_call_after` which should run, before after_callbacks, on
|
||||||
|
successful completion of the transaction. None indicates that no
|
||||||
|
callbacks should be allowed to be scheduled to run.
|
||||||
exception_callbacks: A list that callbacks will be appended
|
exception_callbacks: A list that callbacks will be appended
|
||||||
to that have been added by `call_on_exception` which should be run
|
to that have been added by `call_on_exception` which should be run
|
||||||
if transaction ends with an error. None indicates that no callbacks
|
if transaction ends with an error. None indicates that no callbacks
|
||||||
|
@ -238,6 +248,7 @@ class LoggingTransaction:
|
||||||
"name",
|
"name",
|
||||||
"database_engine",
|
"database_engine",
|
||||||
"after_callbacks",
|
"after_callbacks",
|
||||||
|
"async_after_callbacks",
|
||||||
"exception_callbacks",
|
"exception_callbacks",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -247,12 +258,14 @@ class LoggingTransaction:
|
||||||
name: str,
|
name: str,
|
||||||
database_engine: BaseDatabaseEngine,
|
database_engine: BaseDatabaseEngine,
|
||||||
after_callbacks: Optional[List[_CallbackListEntry]] = None,
|
after_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||||
|
async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
|
||||||
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
|
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||||
):
|
):
|
||||||
self.txn = txn
|
self.txn = txn
|
||||||
self.name = name
|
self.name = name
|
||||||
self.database_engine = database_engine
|
self.database_engine = database_engine
|
||||||
self.after_callbacks = after_callbacks
|
self.after_callbacks = after_callbacks
|
||||||
|
self.async_after_callbacks = async_after_callbacks
|
||||||
self.exception_callbacks = exception_callbacks
|
self.exception_callbacks = exception_callbacks
|
||||||
|
|
||||||
def call_after(
|
def call_after(
|
||||||
|
@ -277,6 +290,28 @@ class LoggingTransaction:
|
||||||
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||||
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def async_call_after(
|
||||||
|
self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
|
||||||
|
) -> None:
|
||||||
|
"""Call the given asynchronous callback on the main twisted thread after
|
||||||
|
the transaction has finished (but before those added in `call_after`).
|
||||||
|
|
||||||
|
Mostly used to invalidate remote caches after transactions.
|
||||||
|
|
||||||
|
Note that transactions may be retried a few times if they encounter database
|
||||||
|
errors such as serialization failures. Callbacks given to `async_call_after`
|
||||||
|
will accumulate across transaction attempts and will _all_ be called once a
|
||||||
|
transaction attempt succeeds, regardless of whether previous transaction
|
||||||
|
attempts failed. Otherwise, if all transaction attempts fail, all
|
||||||
|
`call_on_exception` callbacks will be run instead.
|
||||||
|
"""
|
||||||
|
# if self.async_after_callbacks is None, that means that whatever constructed the
|
||||||
|
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||||
|
# is not the case.
|
||||||
|
assert self.async_after_callbacks is not None
|
||||||
|
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||||
|
self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||||
|
|
||||||
def call_on_exception(
|
def call_on_exception(
|
||||||
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
|
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -574,6 +609,7 @@ class DatabasePool:
|
||||||
conn: LoggingDatabaseConnection,
|
conn: LoggingDatabaseConnection,
|
||||||
desc: str,
|
desc: str,
|
||||||
after_callbacks: List[_CallbackListEntry],
|
after_callbacks: List[_CallbackListEntry],
|
||||||
|
async_after_callbacks: List[_AsyncCallbackListEntry],
|
||||||
exception_callbacks: List[_CallbackListEntry],
|
exception_callbacks: List[_CallbackListEntry],
|
||||||
func: Callable[Concatenate[LoggingTransaction, P], R],
|
func: Callable[Concatenate[LoggingTransaction, P], R],
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
|
@ -597,6 +633,7 @@ class DatabasePool:
|
||||||
conn
|
conn
|
||||||
desc
|
desc
|
||||||
after_callbacks
|
after_callbacks
|
||||||
|
async_after_callbacks
|
||||||
exception_callbacks
|
exception_callbacks
|
||||||
func
|
func
|
||||||
*args
|
*args
|
||||||
|
@ -659,6 +696,7 @@ class DatabasePool:
|
||||||
cursor = conn.cursor(
|
cursor = conn.cursor(
|
||||||
txn_name=name,
|
txn_name=name,
|
||||||
after_callbacks=after_callbacks,
|
after_callbacks=after_callbacks,
|
||||||
|
async_after_callbacks=async_after_callbacks,
|
||||||
exception_callbacks=exception_callbacks,
|
exception_callbacks=exception_callbacks,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
@ -798,6 +836,7 @@ class DatabasePool:
|
||||||
|
|
||||||
async def _runInteraction() -> R:
|
async def _runInteraction() -> R:
|
||||||
after_callbacks: List[_CallbackListEntry] = []
|
after_callbacks: List[_CallbackListEntry] = []
|
||||||
|
async_after_callbacks: List[_AsyncCallbackListEntry] = []
|
||||||
exception_callbacks: List[_CallbackListEntry] = []
|
exception_callbacks: List[_CallbackListEntry] = []
|
||||||
|
|
||||||
if not current_context():
|
if not current_context():
|
||||||
|
@ -809,6 +848,7 @@ class DatabasePool:
|
||||||
self.new_transaction,
|
self.new_transaction,
|
||||||
desc,
|
desc,
|
||||||
after_callbacks,
|
after_callbacks,
|
||||||
|
async_after_callbacks,
|
||||||
exception_callbacks,
|
exception_callbacks,
|
||||||
func,
|
func,
|
||||||
*args,
|
*args,
|
||||||
|
@ -817,15 +857,17 @@ class DatabasePool:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We order these assuming that async functions call out to external
|
||||||
|
# systems (e.g. to invalidate a cache) and the sync functions make these
|
||||||
|
# changes on any local in-memory caches/similar, and thus must be second.
|
||||||
|
for async_callback, async_args, async_kwargs in async_after_callbacks:
|
||||||
|
await async_callback(*async_args, **async_kwargs)
|
||||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||||
await maybe_awaitable(after_callback(*after_args, **after_kwargs))
|
after_callback(*after_args, **after_kwargs)
|
||||||
|
|
||||||
return cast(R, result)
|
return cast(R, result)
|
||||||
except Exception:
|
except Exception:
|
||||||
for exception_callback, after_args, after_kwargs in exception_callbacks:
|
for exception_callback, after_args, after_kwargs in exception_callbacks:
|
||||||
await maybe_awaitable(
|
exception_callback(*after_args, **after_kwargs)
|
||||||
exception_callback(*after_args, **after_kwargs)
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# To handle cancellation, we ensure that `after_callback`s and
|
# To handle cancellation, we ensure that `after_callback`s and
|
||||||
|
|
|
@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||||
# changed its content in the database. We can't call
|
# changed its content in the database. We can't call
|
||||||
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
|
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
|
||||||
# right type.
|
# right type.
|
||||||
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
|
self.invalidate_get_event_cache_after_txn(txn, event.event_id)
|
||||||
# Send that invalidation to replication so that other workers also invalidate
|
# Send that invalidation to replication so that other workers also invalidate
|
||||||
# the event cache.
|
# the event cache.
|
||||||
self._send_invalidation_to_replication(
|
self._send_invalidation_to_replication(
|
||||||
|
|
|
@ -1293,7 +1293,7 @@ class PersistEventsStore:
|
||||||
depth_updates: Dict[str, int] = {}
|
depth_updates: Dict[str, int] = {}
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
# Remove the any existing cache entries for the event_ids
|
# Remove the any existing cache entries for the event_ids
|
||||||
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
|
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
|
||||||
# Then update the `stream_ordering` position to mark the latest
|
# Then update the `stream_ordering` position to mark the latest
|
||||||
# event as the front of the room. This should not be done for
|
# event as the front of the room. This should not be done for
|
||||||
# backfilled events because backfilled events have negative
|
# backfilled events because backfilled events have negative
|
||||||
|
@ -1675,7 +1675,7 @@ class PersistEventsStore:
|
||||||
(cache_entry.event.event_id,), cache_entry
|
(cache_entry.event.event_id,), cache_entry
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.call_after(prefill)
|
txn.async_call_after(prefill)
|
||||||
|
|
||||||
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
|
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||||
"""Invalidate the caches for the redacted event.
|
"""Invalidate the caches for the redacted event.
|
||||||
|
@ -1684,7 +1684,7 @@ class PersistEventsStore:
|
||||||
_invalidate_caches_for_event.
|
_invalidate_caches_for_event.
|
||||||
"""
|
"""
|
||||||
assert event.redacts is not None
|
assert event.redacts is not None
|
||||||
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
|
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
|
||||||
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
|
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
|
||||||
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
|
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
|
||||||
|
|
||||||
|
|
|
@ -712,17 +712,41 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return event_entry_map
|
return event_entry_map
|
||||||
|
|
||||||
async def _invalidate_get_event_cache(self, event_id: str) -> None:
|
def invalidate_get_event_cache_after_txn(
|
||||||
# First we invalidate the asynchronous cache instance. This may include
|
self, txn: LoggingTransaction, event_id: str
|
||||||
# out-of-process caches such as Redis/memcache. Once complete we can
|
) -> None:
|
||||||
# invalidate any in memory cache. The ordering is important here to
|
"""
|
||||||
# ensure we don't pull in any remote invalid value after we invalidate
|
Prepares a database transaction to invalidate the get event cache for a given
|
||||||
# the in-memory cache.
|
event ID when executed successfully. This is achieved by attaching two callbacks
|
||||||
|
to the transaction, one to invalidate the async cache and one for the in memory
|
||||||
|
sync cache (importantly called in that order).
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
txn: the database transaction to attach the callbacks to
|
||||||
|
event_id: the event ID to be invalidated from caches
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
|
||||||
|
txn.call_after(self._invalidate_local_get_event_cache, event_id)
|
||||||
|
|
||||||
|
async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Invalidates an event in the asyncronous get event cache, which may be remote.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
event_id: the event ID to invalidate
|
||||||
|
"""
|
||||||
|
|
||||||
await self._get_event_cache.invalidate((event_id,))
|
await self._get_event_cache.invalidate((event_id,))
|
||||||
self._event_ref.pop(event_id, None)
|
|
||||||
self._current_event_fetches.pop(event_id, None)
|
|
||||||
|
|
||||||
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
|
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Invalidates an event in local in-memory get event caches.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
event_id: the event ID to invalidate
|
||||||
|
"""
|
||||||
|
|
||||||
self._get_event_cache.invalidate_local((event_id,))
|
self._get_event_cache.invalidate_local((event_id,))
|
||||||
self._event_ref.pop(event_id, None)
|
self._event_ref.pop(event_id, None)
|
||||||
self._current_event_fetches.pop(event_id, None)
|
self._current_event_fetches.pop(event_id, None)
|
||||||
|
@ -958,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
|
|
||||||
row_dict = self.db_pool.new_transaction(
|
row_dict = self.db_pool.new_transaction(
|
||||||
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
|
conn,
|
||||||
|
"do_fetch",
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
self._fetch_event_rows,
|
||||||
|
events_to_fetch,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We only want to resolve deferreds from the main thread
|
# We only want to resolve deferreds from the main thread
|
||||||
|
|
|
@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
|
||||||
"initialise_mau_threepids",
|
"initialise_mau_threepids",
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
|
[],
|
||||||
self._initialise_reserved_users,
|
self._initialise_reserved_users,
|
||||||
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
||||||
)
|
)
|
||||||
|
|
|
@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.have_seen_event, (room_id, event_id)
|
txn, self.have_seen_event, (room_id, event_id)
|
||||||
)
|
)
|
||||||
txn.call_after(self._invalidate_get_event_cache, event_id)
|
self.invalidate_get_event_cache_after_txn(txn, event_id)
|
||||||
|
|
||||||
logger.info("[purge] done")
|
logger.info("[purge] done")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue