Change Cache to not use *args in its interface
commit 20addfa358
parent 63b1eaf32c

11 changed files with 69 additions and 67 deletions
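In short: every Cache method now takes the key as a single tuple argument instead of splatting its parts across *args, and get() gains an optional default. A minimal before/after sketch of the call convention, using the client-IP key from the first hunk below (cache stands in for self.client_ip_last_seen):

    key = (user.to_string(), access_token, device_id, ip)

    # before: key parts passed positionally, value appended to the args
    last_seen = cache.get(*key)
    cache.prefill(*key + (now,))

    # after: the key tuple is one argument, the value another
    last_seen = cache.get(key)
    cache.prefill(key, now)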
@@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
         key = (user.to_string(), access_token, device_id, ip)
 
         try:
-            last_seen = self.client_ip_last_seen.get(*key)
+            last_seen = self.client_ip_last_seen.get(key)
         except KeyError:
             last_seen = None
 
@@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
             defer.returnValue(None)
 
-        self.client_ip_last_seen.prefill(*key + (now,))
+        self.client_ip_last_seen.prefill(key, now)
 
         # It's safe not to lock here: a) no unique constraint,
         # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@@ -57,6 +57,9 @@ cache_counter = metrics.register_cache(
 )
 
 
+_CacheSentinel = object()
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@@ -83,41 +86,38 @@ class Cache(object):
                 "Cache objects can only be accessed from the main thread"
             )
 
-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
+    def get(self, keyargs, default=_CacheSentinel):
+        val = self.cache.get(keyargs, _CacheSentinel)
+        if val is not _CacheSentinel:
             cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
+            return val
 
         cache_counter.inc_misses(self.name)
-        raise KeyError()
 
-    def update(self, sequence, *args):
+        if default is _CacheSentinel:
+            raise KeyError()
+        else:
+            return default
+
+    def update(self, sequence, keyargs, value):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+            self.prefill(keyargs, value)
 
+    def prefill(self, keyargs, value):
         if self.max_entries is not None:
             while len(self.cache) >= self.max_entries:
                 self.cache.popitem(last=False)
 
         self.cache[keyargs] = value
 
-    def invalidate(self, *keyargs):
+    def invalidate(self, keyargs):
         self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
+        if not isinstance(keyargs, tuple):
+            raise ValueError("keyargs must be a tuple.")
 
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
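The new _CacheSentinel default lets get() distinguish a genuine miss from a cached value that happens to be None or otherwise falsey, and lets callers opt out of the KeyError entirely. A small usage sketch under these semantics (hypothetical keys, not taken from the commit):

    cache = Cache(name="example", keylen=1)

    cache.prefill(("foo",), None)      # None is a perfectly valid cached value
    cache.get(("foo",))                # hit: returns None
    cache.get(("bar",), default=42)    # miss: returns 42 instead of raising
    cache.get(("bar",))                # miss, no default: raises KeyError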
@@ -168,20 +168,21 @@ class CacheDescriptor(object):
             % (orig.__name__,)
         )
 
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
+        self.cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             lru=self.lru,
         )
 
+    def __get__(self, obj, objtype=None):
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
+            keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result_d = cache.get(*keyargs)
+                cached_result_d = self.cache.get(keyargs)
 
                 if DEBUG_CACHES:
@@ -202,7 +203,7 @@ class CacheDescriptor(object):
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
+                sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
                     self.function_to_call,
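The list-to-tuple switch when building keyargs matters because the key now goes straight into the underlying dict/LruCache, and only the tuple form is hashable. A self-contained sketch of the same inspect.getcallargs technique (toy function and values, not synapse code):

    import inspect

    def get_aliases_for_room(self, room_id):
        pass

    arg_names = ["room_id"]
    arg_dict = inspect.getcallargs(get_aliases_for_room, "store", "!abc:example.com")
    keyargs = tuple(arg_dict[name] for name in arg_names)
    # keyargs == ("!abc:example.com",), hashable and matching the
    # invalidate((room_id,)) call sites updated below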
@@ -210,19 +211,19 @@ class CacheDescriptor(object):
                 )
 
                 def onErr(f):
-                    cache.invalidate(*keyargs)
+                    self.cache.invalidate(keyargs)
                     return f
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=False)
-                cache.update(sequence, *(keyargs + [ret]))
+                ret = ObservableDeferred(ret, consumeErrors=True)
+                self.cache.update(sequence, keyargs, ret)
 
                 return ret.observe()
 
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
+        wrapped.invalidate = self.cache.invalidate
+        wrapped.invalidate_all = self.cache.invalidate_all
+        wrapped.prefill = self.cache.prefill
 
         obj.__dict__[self.orig.__name__] = wrapped
 
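Also tucked into this hunk: the ObservableDeferred is now built with consumeErrors=True. Each observer returned by ret.observe() already receives the failure, and onErr has removed the entry from the cache, so the original Deferred should swallow its failure rather than have Twisted additionally report it as unhandled. A plain-Twisted illustration of what consuming a failure means (not synapse's ObservableDeferred):

    from twisted.internet import defer

    d = defer.Deferred()
    d.addErrback(lambda failure: None)  # errback returns None, so the failure is consumed
    d.errback(ValueError("boom"))
    # without a consuming errback, garbage-collecting d would log
    # "Unhandled error in Deferred" even though observers handled it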
@@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
             },
             desc="create_room_alias_association",
         )
-        self.get_aliases_for_room.invalidate(room_id)
+        self.get_aliases_for_room.invalidate((room_id,))
 
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
@@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
             room_alias,
         )
 
-        self.get_aliases_for_room.invalidate(room_id)
+        self.get_aliases_for_room.invalidate((room_id,))
         defer.returnValue(room_id)
 
     def _delete_room_alias_txn(self, txn, room_alias):
@@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
 
         for room_id in events_by_room:
             txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, room_id
+                self.get_latest_event_ids_in_room.invalidate, (room_id,)
             )
 
     def get_backfill_events(self, room_id, event_list, limit):
@@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
-        txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
+        txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
@@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
         if current_state:
             txn.call_after(self.get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
             txn.call_after(self.get_room_name_and_aliases, event.room_id)
 
             self._simple_delete_txn(
@@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
         if not context.rejected:
             txn.call_after(
                 self.get_current_state_for_key.invalidate,
-                event.room_id, event.type, event.state_key
+                (event.room_id, event.type, event.state_key,)
             )
 
             if event.type in [EventTypes.Name, EventTypes.Aliases]:
                 txn.call_after(
                     self.get_room_name_and_aliases.invalidate,
-                    event.room_id
+                    (event.room_id,)
                 )
 
             self._simple_upsert_txn(
@@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
     def _invalidate_get_event_cache(self, event_id):
         for check_redacted in (False, True):
             for get_prev_content in (False, True):
-                self._get_event_cache.invalidate(event_id, check_redacted,
-                                                 get_prev_content)
+                self._get_event_cache.invalidate(
+                    (event_id, check_redacted, get_prev_content)
+                )
 
     def _get_event_txn(self, txn, event_id, check_redacted=True,
                        get_prev_content=False, allow_rejected=False):
@@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
         for event_id in events:
             try:
                 ret = self._get_event_cache.get(
-                    event_id, check_redacted, get_prev_content
+                    (event_id, check_redacted, get_prev_content,)
                 )
 
                 if allow_rejected or not ret.rejected_reason:
@@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
             ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )
 
         defer.returnValue(ev)
@@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
             ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
         self._get_event_cache.prefill(
-            ev.event_id, check_redacted, get_prev_content, ev
+            (ev.event_id, check_redacted, get_prev_content), ev
         )
 
         return ev
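Because _get_event_cache is keyed on the full (event_id, check_redacted, get_prev_content) tuple, invalidating one event means enumerating every flag combination, which is exactly what _invalidate_get_event_cache does above. An equivalent standalone sketch of that pattern (hypothetical helper name):

    from itertools import product

    def invalidate_event(cache, event_id):
        # drop all four flag combinations cached for this event
        for check_redacted, get_prev_content in product((False, True), repeat=2):
            cache.invalidate((event_id, check_redacted, get_prev_content))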
@@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
-        self.get_all_server_verify_keys.invalidate(server_name)
+        self.get_all_server_verify_keys.invalidate((server_name,))
 
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
@@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
             updatevalues={"accepted": True},
             desc="set_presence_list_accepted",
         )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))
         defer.returnValue(result)
 
     def get_presence_list(self, observer_localpart, accepted=None):
@@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
             "observed_user_id": observed_userid},
            desc="del_presence_list",
        )
-        self.get_presence_list_accepted.invalidate(observer_localpart)
+        self.get_presence_list_accepted.invalidate((observer_localpart,))
@@ -151,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
             txn.execute(sql, (user_name, priority_class, new_rule_priority))
 
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
         )
 
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )
 
         self._simple_insert_txn(
@@ -187,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
             new_rule['priority'] = new_prio
 
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
         )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )
 
         self._simple_insert_txn(
@@ -216,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
             desc="delete_push_rule",
         )
 
-        self.get_push_rules_for_user.invalidate(user_name)
-        self.get_push_rules_enabled_for_user.invalidate(user_name)
+        self.get_push_rules_for_user.invalidate((user_name,))
+        self.get_push_rules_enabled_for_user.invalidate((user_name,))
 
     @defer.inlineCallbacks
     def set_push_rule_enabled(self, user_name, rule_id, enabled):
@@ -238,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
             {'id': new_id},
         )
         txn.call_after(
-            self.get_push_rules_for_user.invalidate, user_name
+            self.get_push_rules_for_user.invalidate, (user_name,)
        )
         txn.call_after(
-            self.get_push_rules_enabled_for_user.invalidate, user_name
+            self.get_push_rules_enabled_for_user.invalidate, (user_name,)
         )
 
@@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
             user_id
         )
         for r in rows:
-            self.get_user_by_token.invalidate(r)
+            self.get_user_by_token.invalidate((r,))
 
     @cached()
     def get_user_by_token(self, token):
@@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
         )
 
         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
-            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
+            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
 
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -42,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
         self.assertEquals(self.cache.get("foo"), 123)
 
     def test_invalidate(self):
-        self.cache.prefill("foo", 123)
-        self.cache.invalidate("foo")
+        self.cache.prefill(("foo",), 123)
+        self.cache.invalidate(("foo",))
 
         failed = False
         try:
-            self.cache.get("foo")
+            self.cache.get(("foo",))
         except KeyError:
             failed = True
 
@@ -141,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 1)
 
-        a.func.invalidate("foo")
+        a.func.invalidate(("foo",))
 
         yield a.func("foo")
 
@@ -153,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
            def func(self, key):
                return key
 
-        A().func.invalidate("what")
+        A().func.invalidate(("what",))
 
     @defer.inlineCallbacks
     def test_max_entries(self):
@@ -193,7 +193,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a = A()
 
-        a.func.prefill("foo", ObservableDeferred(d))
+        a.func.prefill(("foo",), ObservableDeferred(d))
 
         self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
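The updated tests exercise tuple keys but not the new default argument to get(); a test in the same style might look like this (a hypothetical addition, not part of the commit):

    def test_get_with_default(self):
        self.cache.prefill(("foo",), 123)

        self.assertEquals(self.cache.get(("foo",)), 123)
        self.assertEquals(self.cache.get(("bar",), "absent"), "absent")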