forked from MirrorHub/synapse
Pull out the cache logic from the @cached wrapper into its own class we can reuse
This commit is contained in:
parent
b1022ed8b5
commit
0f86312c4c
2 changed files with 87 additions and 36 deletions
|
@ -53,6 +53,47 @@ cache_counter = metrics.register_cache(
|
|||
)
|
||||
|
||||
|
||||
class Cache(object):
|
||||
|
||||
def __init__(self, name, max_entries=1000, keylen=1):
|
||||
self.cache = OrderedDict()
|
||||
|
||||
self.max_entries = max_entries
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
|
||||
caches_by_name[name] = self.cache
|
||||
|
||||
def get(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
if keyargs in self.cache:
|
||||
cache_counter.inc_hits(self.name)
|
||||
return self.cache[keyargs]
|
||||
|
||||
cache_counter.inc_misses(self.name)
|
||||
raise KeyError()
|
||||
|
||||
def prefill(self, *args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
while len(self.cache) > self.max_entries:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
self.cache[keyargs] = value
|
||||
|
||||
def invalidate(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
self.cache.pop(keyargs, None)
|
||||
|
||||
|
||||
# TODO(paul):
|
||||
# * consider other eviction strategies - LRU?
|
||||
def cached(max_entries=1000, num_args=1):
|
||||
|
@ -70,48 +111,26 @@ def cached(max_entries=1000, num_args=1):
|
|||
calling the calculation function.
|
||||
"""
|
||||
def wrap(orig):
|
||||
cache = OrderedDict()
|
||||
name = orig.__name__
|
||||
|
||||
caches_by_name[name] = cache
|
||||
|
||||
def prefill(*args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
|
||||
if len(keyargs) != num_args:
|
||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||
|
||||
while len(cache) > max_entries:
|
||||
cache.popitem(last=False)
|
||||
|
||||
cache[keyargs] = value
|
||||
cache = Cache(
|
||||
name=orig.__name__,
|
||||
max_entries=max_entries,
|
||||
keylen=num_args,
|
||||
)
|
||||
|
||||
@functools.wraps(orig)
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(self, *keyargs):
|
||||
if len(keyargs) != num_args:
|
||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||
try:
|
||||
defer.returnValue(cache.get(*keyargs))
|
||||
except KeyError:
|
||||
ret = yield orig(self, *keyargs)
|
||||
|
||||
if keyargs in cache:
|
||||
cache_counter.inc_hits(name)
|
||||
defer.returnValue(cache[keyargs])
|
||||
cache.prefill(*keyargs + (ret,))
|
||||
|
||||
cache_counter.inc_misses(name)
|
||||
ret = yield orig(self, *keyargs)
|
||||
defer.returnValue(ret)
|
||||
|
||||
prefill(*keyargs + (ret,))
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
def invalidate(*keyargs):
|
||||
if len(keyargs) != num_args:
|
||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||
|
||||
cache.pop(keyargs, None)
|
||||
|
||||
wrapped.invalidate = invalidate
|
||||
wrapped.prefill = prefill
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.prefill = cache.prefill
|
||||
return wrapped
|
||||
|
||||
return wrap
|
||||
|
|
|
@ -17,7 +17,39 @@
|
|||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import cached
|
||||
from synapse.storage._base import Cache, cached
|
||||
|
||||
|
||||
class CacheTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cache = Cache("test")
|
||||
|
||||
def test_empty(self):
|
||||
failed = False
|
||||
try:
|
||||
self.cache.get("foo")
|
||||
except KeyError:
|
||||
failed = True
|
||||
|
||||
self.assertTrue(failed)
|
||||
|
||||
def test_hit(self):
|
||||
self.cache.prefill("foo", 123)
|
||||
|
||||
self.assertEquals(self.cache.get("foo"), 123)
|
||||
|
||||
def test_invalidate(self):
|
||||
self.cache.prefill("foo", 123)
|
||||
self.cache.invalidate("foo")
|
||||
|
||||
failed = False
|
||||
try:
|
||||
self.cache.get("foo")
|
||||
except KeyError:
|
||||
failed = True
|
||||
|
||||
self.assertTrue(failed)
|
||||
|
||||
|
||||
class CacheDecoratorTestCase(unittest.TestCase):
|
||||
|
|
Loading…
Reference in a new issue