Pull out the cache logic from the @cached wrapper into its own class we can reuse

This commit is contained in:
Paul "LeoNerd" Evans 2015-03-20 18:13:49 +00:00
parent b1022ed8b5
commit 0f86312c4c
2 changed files with 87 additions and 36 deletions

View file

@ -53,6 +53,47 @@ cache_counter = metrics.register_cache(
) )
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1):
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
caches_by_name[name] = self.cache
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if keyargs in self.cache:
cache_counter.inc_hits(self.name)
return self.cache[keyargs]
cache_counter.inc_misses(self.name)
raise KeyError()
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
while len(self.cache) > self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value
def invalidate(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
self.cache.pop(keyargs, None)
# TODO(paul): # TODO(paul):
# * consider other eviction strategies - LRU? # * consider other eviction strategies - LRU?
def cached(max_entries=1000, num_args=1): def cached(max_entries=1000, num_args=1):
@ -70,48 +111,26 @@ def cached(max_entries=1000, num_args=1):
calling the calculation function. calling the calculation function.
""" """
def wrap(orig): def wrap(orig):
cache = OrderedDict() cache = Cache(
name = orig.__name__ name=orig.__name__,
max_entries=max_entries,
caches_by_name[name] = cache keylen=num_args,
)
def prefill(*args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != num_args:
raise ValueError("Expected a call to have %d arguments", num_args)
while len(cache) > max_entries:
cache.popitem(last=False)
cache[keyargs] = value
@functools.wraps(orig) @functools.wraps(orig)
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped(self, *keyargs): def wrapped(self, *keyargs):
if len(keyargs) != num_args: try:
raise ValueError("Expected a call to have %d arguments", num_args) defer.returnValue(cache.get(*keyargs))
except KeyError:
ret = yield orig(self, *keyargs)
if keyargs in cache: cache.prefill(*keyargs + (ret,))
cache_counter.inc_hits(name)
defer.returnValue(cache[keyargs])
cache_counter.inc_misses(name) defer.returnValue(ret)
ret = yield orig(self, *keyargs)
prefill(*keyargs + (ret,)) wrapped.invalidate = cache.invalidate
wrapped.prefill = cache.prefill
defer.returnValue(ret)
def invalidate(*keyargs):
if len(keyargs) != num_args:
raise ValueError("Expected a call to have %d arguments", num_args)
cache.pop(keyargs, None)
wrapped.invalidate = invalidate
wrapped.prefill = prefill
return wrapped return wrapped
return wrap return wrap

View file

@ -17,7 +17,39 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import cached from synapse.storage._base import Cache, cached
class CacheTestCase(unittest.TestCase):
def setUp(self):
self.cache = Cache("test")
def test_empty(self):
failed = False
try:
self.cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
self.cache.prefill("foo", 123)
self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self):
self.cache.prefill("foo", 123)
self.cache.invalidate("foo")
failed = False
try:
self.cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
class CacheDecoratorTestCase(unittest.TestCase): class CacheDecoratorTestCase(unittest.TestCase):