mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-02 23:44:12 +01:00
Merge pull request #139 from matrix-org/bugs/SYN-369
Fix race with cache invalidation. SYN-369
This commit is contained in:
commit
31049c4d72
4 changed files with 80 additions and 16 deletions
|
@ -31,7 +31,9 @@ import functools
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import threading
|
||||||
|
|
||||||
|
DEBUG_CACHES = False
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -68,9 +70,20 @@ class Cache(object):
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.keylen = keylen
|
self.keylen = keylen
|
||||||
|
self.sequence = 0
|
||||||
|
self.thread = None
|
||||||
caches_by_name[name] = self.cache
|
caches_by_name[name] = self.cache
|
||||||
|
|
||||||
|
def check_thread(self):
|
||||||
|
expected_thread = self.thread
|
||||||
|
if expected_thread is None:
|
||||||
|
self.thread = threading.current_thread()
|
||||||
|
else:
|
||||||
|
if expected_thread is not threading.current_thread():
|
||||||
|
raise ValueError(
|
||||||
|
"Cache objects can only be accessed from the main thread"
|
||||||
|
)
|
||||||
|
|
||||||
def get(self, *keyargs):
|
def get(self, *keyargs):
|
||||||
if len(keyargs) != self.keylen:
|
if len(keyargs) != self.keylen:
|
||||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
@ -82,6 +95,13 @@ class Cache(object):
|
||||||
cache_counter.inc_misses(self.name)
|
cache_counter.inc_misses(self.name)
|
||||||
raise KeyError()
|
raise KeyError()
|
||||||
|
|
||||||
|
def update(self, sequence, *args):
|
||||||
|
self.check_thread()
|
||||||
|
if self.sequence == sequence:
|
||||||
|
# Only update the cache if the cache's sequence number matches the
|
||||||
|
# number that the cache had before the SELECT was started (SYN-369)
|
||||||
|
self.prefill(*args)
|
||||||
|
|
||||||
def prefill(self, *args): # because I can't *keyargs, value
|
def prefill(self, *args): # because I can't *keyargs, value
|
||||||
keyargs = args[:-1]
|
keyargs = args[:-1]
|
||||||
value = args[-1]
|
value = args[-1]
|
||||||
|
@ -96,9 +116,12 @@ class Cache(object):
|
||||||
self.cache[keyargs] = value
|
self.cache[keyargs] = value
|
||||||
|
|
||||||
def invalidate(self, *keyargs):
|
def invalidate(self, *keyargs):
|
||||||
|
self.check_thread()
|
||||||
if len(keyargs) != self.keylen:
|
if len(keyargs) != self.keylen:
|
||||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
# Increment the sequence number so that any SELECT statements that
|
||||||
|
# raced with the INSERT don't update the cache (SYN-369)
|
||||||
|
self.sequence += 1
|
||||||
self.cache.pop(keyargs, None)
|
self.cache.pop(keyargs, None)
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,11 +151,26 @@ def cached(max_entries=1000, num_args=1, lru=False):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wrapped(self, *keyargs):
|
def wrapped(self, *keyargs):
|
||||||
try:
|
try:
|
||||||
defer.returnValue(cache.get(*keyargs))
|
cached_result = cache.get(*keyargs)
|
||||||
|
if DEBUG_CACHES:
|
||||||
|
actual_result = yield orig(self, *keyargs)
|
||||||
|
if actual_result != cached_result:
|
||||||
|
logger.error(
|
||||||
|
"Stale cache entry %s%r: cached: %r, actual %r",
|
||||||
|
orig.__name__, keyargs,
|
||||||
|
cached_result, actual_result,
|
||||||
|
)
|
||||||
|
raise ValueError("Stale cache entry")
|
||||||
|
defer.returnValue(cached_result)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
# Get the sequence number of the cache before reading from the
|
||||||
|
# database so that we can tell if the cache is invalidated
|
||||||
|
# while the SELECT is executing (SYN-369)
|
||||||
|
sequence = cache.sequence
|
||||||
|
|
||||||
ret = yield orig(self, *keyargs)
|
ret = yield orig(self, *keyargs)
|
||||||
|
|
||||||
cache.prefill(*keyargs + (ret,))
|
cache.update(sequence, *keyargs + (ret,))
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@ -147,12 +185,20 @@ class LoggingTransaction(object):
|
||||||
"""An object that almost-transparently proxies for the 'txn' object
|
"""An object that almost-transparently proxies for the 'txn' object
|
||||||
passed to the constructor. Adds logging and metrics to the .execute()
|
passed to the constructor. Adds logging and metrics to the .execute()
|
||||||
method."""
|
method."""
|
||||||
__slots__ = ["txn", "name", "database_engine"]
|
__slots__ = ["txn", "name", "database_engine", "after_callbacks"]
|
||||||
|
|
||||||
def __init__(self, txn, name, database_engine):
|
def __init__(self, txn, name, database_engine, after_callbacks):
|
||||||
object.__setattr__(self, "txn", txn)
|
object.__setattr__(self, "txn", txn)
|
||||||
object.__setattr__(self, "name", name)
|
object.__setattr__(self, "name", name)
|
||||||
object.__setattr__(self, "database_engine", database_engine)
|
object.__setattr__(self, "database_engine", database_engine)
|
||||||
|
object.__setattr__(self, "after_callbacks", after_callbacks)
|
||||||
|
|
||||||
|
def call_after(self, callback, *args):
|
||||||
|
"""Call the given callback on the main twisted thread after the
|
||||||
|
transaction has finished. Used to invalidate the caches on the
|
||||||
|
correct thread.
|
||||||
|
"""
|
||||||
|
self.after_callbacks.append((callback, args))
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self.txn, name)
|
return getattr(self.txn, name)
|
||||||
|
@ -294,6 +340,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
start_time = time.time() * 1000
|
start_time = time.time() * 1000
|
||||||
|
|
||||||
|
after_callbacks = []
|
||||||
|
|
||||||
def inner_func(conn, *args, **kwargs):
|
def inner_func(conn, *args, **kwargs):
|
||||||
with LoggingContext("runInteraction") as context:
|
with LoggingContext("runInteraction") as context:
|
||||||
if self.database_engine.is_connection_closed(conn):
|
if self.database_engine.is_connection_closed(conn):
|
||||||
|
@ -318,10 +366,10 @@ class SQLBaseStore(object):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
return func(
|
txn = LoggingTransaction(
|
||||||
LoggingTransaction(txn, name, self.database_engine),
|
txn, name, self.database_engine, after_callbacks
|
||||||
*args, **kwargs
|
|
||||||
)
|
)
|
||||||
|
return func(txn, *args, **kwargs)
|
||||||
except self.database_engine.module.OperationalError as e:
|
except self.database_engine.module.OperationalError as e:
|
||||||
# This can happen if the database disappears mid
|
# This can happen if the database disappears mid
|
||||||
# transaction.
|
# transaction.
|
||||||
|
@ -370,6 +418,8 @@ class SQLBaseStore(object):
|
||||||
result = yield self._db_pool.runWithConnection(
|
result = yield self._db_pool.runWithConnection(
|
||||||
inner_func, *args, **kwargs
|
inner_func, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
for after_callback, after_args in after_callbacks:
|
||||||
|
after_callback(*after_args)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def cursor_to_dict(self, cursor):
|
def cursor_to_dict(self, cursor):
|
||||||
|
|
|
@ -330,7 +330,9 @@ class EventFederationStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
txn.execute(query)
|
txn.execute(query)
|
||||||
|
|
||||||
self.get_latest_event_ids_in_room.invalidate(room_id)
|
txn.call_after(
|
||||||
|
self.get_latest_event_ids_in_room.invalidate, room_id
|
||||||
|
)
|
||||||
|
|
||||||
def get_backfill_events(self, room_id, event_list, limit):
|
def get_backfill_events(self, room_id, event_list, limit):
|
||||||
"""Get a list of Events for a given topic that occurred before (and
|
"""Get a list of Events for a given topic that occurred before (and
|
||||||
|
|
|
@ -93,7 +93,7 @@ class EventsStore(SQLBaseStore):
|
||||||
current_state=None):
|
current_state=None):
|
||||||
|
|
||||||
# Remove any existing cache entries for the event_id
|
# Remove any existing cache entries for the event_id
|
||||||
self._invalidate_get_event_cache(event.event_id)
|
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||||
|
|
||||||
if stream_ordering is None:
|
if stream_ordering is None:
|
||||||
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
|
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
|
||||||
|
@ -114,6 +114,13 @@ class EventsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
for s in current_state:
|
for s in current_state:
|
||||||
|
if s.type == EventTypes.Member:
|
||||||
|
txn.call_after(
|
||||||
|
self.get_rooms_for_user.invalidate, s.state_key
|
||||||
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self.get_joined_hosts_for_room.invalidate, s.room_id
|
||||||
|
)
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
"current_state_events",
|
"current_state_events",
|
||||||
|
@ -281,7 +288,9 @@ class EventsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if context.rejected:
|
if context.rejected:
|
||||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
self._store_rejections_txn(
|
||||||
|
txn, event.event_id, context.rejected
|
||||||
|
)
|
||||||
|
|
||||||
for hash_alg, hash_base64 in event.hashes.items():
|
for hash_alg, hash_base64 in event.hashes.items():
|
||||||
hash_bytes = decode_base64(hash_base64)
|
hash_bytes = decode_base64(hash_base64)
|
||||||
|
@ -293,7 +302,8 @@ class EventsStore(SQLBaseStore):
|
||||||
for alg, hash_base64 in prev_hashes.items():
|
for alg, hash_base64 in prev_hashes.items():
|
||||||
hash_bytes = decode_base64(hash_base64)
|
hash_bytes = decode_base64(hash_base64)
|
||||||
self._store_prev_event_hash_txn(
|
self._store_prev_event_hash_txn(
|
||||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
txn, event.event_id, prev_event_id, alg,
|
||||||
|
hash_bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
for auth_id, _ in event.auth_events:
|
for auth_id, _ in event.auth_events:
|
||||||
|
@ -356,9 +366,11 @@ class EventsStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
def _store_redaction(self, txn, event):
|
def _store_redaction(self, txn, event):
|
||||||
# invalidate the cache for the redacted event
|
# invalidate the cache for the redacted event
|
||||||
self._invalidate_get_event_cache(event.redacts)
|
txn.call_after(self._invalidate_get_event_cache, event.redacts)
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
|
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
|
||||||
(event.event_id, event.redacts)
|
(event.event_id, event.redacts)
|
||||||
|
|
|
@ -64,8 +64,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_rooms_for_user.invalidate(target_user_id)
|
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
|
||||||
self.get_joined_hosts_for_room.invalidate(event.room_id)
|
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
||||||
|
|
||||||
def get_room_member(self, user_id, room_id):
|
def get_room_member(self, user_id, room_id):
|
||||||
"""Retrieve the current state of a room member.
|
"""Retrieve the current state of a room member.
|
||||||
|
|
Loading…
Reference in a new issue