Merge pull request #215 from matrix-org/erikj/cache_varargs_interface

Change Cache to not use *args in its interface
This commit is contained in:
Erik Johnston 2015-08-10 13:47:45 +01:00
commit 8c3a62b5c7
12 changed files with 73 additions and 72 deletions

View file

@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip) key = (user.to_string(), access_token, device_id, ip)
try: try:
last_seen = self.client_ip_last_seen.get(*key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None) defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,)) self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint, # It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely

View file

@ -57,6 +57,9 @@ cache_counter = metrics.register_cache(
) )
_CacheSentinel = object()
class Cache(object): class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True): def __init__(self, name, max_entries=1000, keylen=1, lru=True):
@ -83,45 +86,42 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, *keyargs): def get(self, key, default=_CacheSentinel):
if len(keyargs) != self.keylen: val = self.cache.get(key, _CacheSentinel)
raise ValueError("Expected a key to have %d items", self.keylen) if val is not _CacheSentinel:
if keyargs in self.cache:
cache_counter.inc_hits(self.name) cache_counter.inc_hits(self.name)
return self.cache[keyargs] return val
cache_counter.inc_misses(self.name) cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args): if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args) self.prefill(key, value)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
def prefill(self, key, value):
if self.max_entries is not None: if self.max_entries is not None:
while len(self.cache) >= self.max_entries: while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False) self.cache.popitem(last=False)
self.cache[keyargs] = value self.cache[key] = value
def invalidate(self, *keyargs): def invalidate(self, key):
self.check_thread() self.check_thread()
if len(keyargs) != self.keylen: if not isinstance(key, tuple):
raise ValueError("Expected a key to have %d items", self.keylen) raise ValueError("keyargs must be a tuple.")
# Increment the sequence number so that any SELECT statements that # Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369) # raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1 self.sequence += 1
self.cache.pop(keyargs, None) self.cache.pop(key, None)
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
@ -168,20 +168,21 @@ class CacheDescriptor(object):
% (orig.__name__,) % (orig.__name__,)
) )
def __get__(self, obj, objtype=None): self.cache = Cache(
cache = Cache(
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
lru=self.lru, lru=self.lru,
) )
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try: try:
cached_result_d = cache.get(*keyargs) cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe() observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
@ -191,7 +192,7 @@ class CacheDescriptor(object):
if actual_result != cached_result: if actual_result != cached_result:
logger.error( logger.error(
"Stale cache entry %s%r: cached: %r, actual %r", "Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs, self.orig.__name__, cache_key,
cached_result, actual_result, cached_result, actual_result,
) )
raise ValueError("Stale cache entry") raise ValueError("Stale cache entry")
@ -203,7 +204,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = self.cache.sequence
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
self.function_to_call, self.function_to_call,
@ -211,19 +212,19 @@ class CacheDescriptor(object):
) )
def onErr(f): def onErr(f):
cache.invalidate(*keyargs) self.cache.invalidate(cache_key)
return f return f
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=False) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, *(keyargs + [ret])) self.cache.update(sequence, cache_key, ret)
return ret.observe() return ret.observe()
wrapped.invalidate = cache.invalidate wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = cache.prefill wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.orig.__name__] = wrapped

View file

@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias, room_alias,
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id) defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):

View file

@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
for room_id in events_by_room: for room_id in events_by_room:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

View file

@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
if current_state: if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id) txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn( self._simple_delete_txn(
@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
if not context.rejected: if not context.rejected:
txn.call_after( txn.call_after(
self.get_current_state_for_key.invalidate, self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key (event.room_id, event.type, event.state_key,)
) )
if event.type in [EventTypes.Name, EventTypes.Aliases]: if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after( txn.call_after(
self.get_room_name_and_aliases.invalidate, self.get_room_name_and_aliases.invalidate,
event.room_id (event.room_id,)
) )
self._simple_upsert_txn( self._simple_upsert_txn(
@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True): for check_redacted in (False, True):
for get_prev_content in (False, True): for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted, self._get_event_cache.invalidate(
get_prev_content) (event_id, check_redacted, get_prev_content)
)
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False): get_prev_content=False, allow_rejected=False):
@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events: for event_id in events:
try: try:
ret = self._get_event_cache.get( ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content (event_id, check_redacted, get_prev_content,)
) )
if allow_rejected or not ret.rejected_reason: if allow_rejected or not ret.rejected_reason:
@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
defer.returnValue(ev) defer.returnValue(ev)
@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
return ev return ev

View file

@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key", desc="store_server_verify_key",
) )
self.get_all_server_verify_keys.invalidate(server_name) self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):

View file

@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted", desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result) defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))

View file

@ -151,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -187,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -216,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate(user_name) self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate(user_name) self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -238,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id}, {'id': new_id},
) )
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )

View file

@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id user_id
) )
for r in rows: for r in rows:
self.get_user_by_token.invalidate(r) self.get_user_by_token.invalidate((r,))
@cached() @cached()
def get_user_by_token(self, token): def get_user_by_token(self, token):

View file

@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
) )
for event in events: for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached, cachedInlineCallbacks from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -91,7 +91,6 @@ class StateStore(SQLBaseStore):
defer.returnValue(dict(state_list)) defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, key, events): def _fetch_events_for_group(self, key, events):
return self._get_events( return self._get_events(
events, get_prev_content=False events, get_prev_content=False

View file

@ -42,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
self.assertEquals(self.cache.get("foo"), 123) self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self): def test_invalidate(self):
self.cache.prefill("foo", 123) self.cache.prefill(("foo",), 123)
self.cache.invalidate("foo") self.cache.invalidate(("foo",))
failed = False failed = False
try: try:
self.cache.get("foo") self.cache.get(("foo",))
except KeyError: except KeyError:
failed = True failed = True
@ -141,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
a.func.invalidate("foo") a.func.invalidate(("foo",))
yield a.func("foo") yield a.func("foo")
@ -153,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
def func(self, key): def func(self, key):
return key return key
A().func.invalidate("what") A().func.invalidate(("what",))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_max_entries(self): def test_max_entries(self):
@ -193,7 +193,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
a = A() a = A()
a.func.prefill("foo", ObservableDeferred(d)) a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals(a.func("foo").result, d.result) self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)