mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 13:03:50 +01:00
Pull out the cache logic from the @cached wrapper into its own class we can reuse
This commit is contained in:
parent
b1022ed8b5
commit
0f86312c4c
2 changed files with 87 additions and 36 deletions
|
@ -53,6 +53,47 @@ cache_counter = metrics.register_cache(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Cache(object):
|
||||||
|
|
||||||
|
def __init__(self, name, max_entries=1000, keylen=1):
|
||||||
|
self.cache = OrderedDict()
|
||||||
|
|
||||||
|
self.max_entries = max_entries
|
||||||
|
self.name = name
|
||||||
|
self.keylen = keylen
|
||||||
|
|
||||||
|
caches_by_name[name] = self.cache
|
||||||
|
|
||||||
|
def get(self, *keyargs):
|
||||||
|
if len(keyargs) != self.keylen:
|
||||||
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
|
||||||
|
if keyargs in self.cache:
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
return self.cache[keyargs]
|
||||||
|
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
raise KeyError()
|
||||||
|
|
||||||
|
def prefill(self, *args): # because I can't *keyargs, value
|
||||||
|
keyargs = args[:-1]
|
||||||
|
value = args[-1]
|
||||||
|
|
||||||
|
if len(keyargs) != self.keylen:
|
||||||
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
|
||||||
|
while len(self.cache) > self.max_entries:
|
||||||
|
self.cache.popitem(last=False)
|
||||||
|
|
||||||
|
self.cache[keyargs] = value
|
||||||
|
|
||||||
|
def invalidate(self, *keyargs):
|
||||||
|
if len(keyargs) != self.keylen:
|
||||||
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
|
||||||
|
self.cache.pop(keyargs, None)
|
||||||
|
|
||||||
|
|
||||||
# TODO(paul):
|
# TODO(paul):
|
||||||
# * consider other eviction strategies - LRU?
|
# * consider other eviction strategies - LRU?
|
||||||
def cached(max_entries=1000, num_args=1):
|
def cached(max_entries=1000, num_args=1):
|
||||||
|
@ -70,48 +111,26 @@ def cached(max_entries=1000, num_args=1):
|
||||||
calling the calculation function.
|
calling the calculation function.
|
||||||
"""
|
"""
|
||||||
def wrap(orig):
|
def wrap(orig):
|
||||||
cache = OrderedDict()
|
cache = Cache(
|
||||||
name = orig.__name__
|
name=orig.__name__,
|
||||||
|
max_entries=max_entries,
|
||||||
caches_by_name[name] = cache
|
keylen=num_args,
|
||||||
|
)
|
||||||
def prefill(*args): # because I can't *keyargs, value
|
|
||||||
keyargs = args[:-1]
|
|
||||||
value = args[-1]
|
|
||||||
|
|
||||||
if len(keyargs) != num_args:
|
|
||||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
|
||||||
|
|
||||||
while len(cache) > max_entries:
|
|
||||||
cache.popitem(last=False)
|
|
||||||
|
|
||||||
cache[keyargs] = value
|
|
||||||
|
|
||||||
@functools.wraps(orig)
|
@functools.wraps(orig)
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wrapped(self, *keyargs):
|
def wrapped(self, *keyargs):
|
||||||
if len(keyargs) != num_args:
|
try:
|
||||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
defer.returnValue(cache.get(*keyargs))
|
||||||
|
except KeyError:
|
||||||
if keyargs in cache:
|
|
||||||
cache_counter.inc_hits(name)
|
|
||||||
defer.returnValue(cache[keyargs])
|
|
||||||
|
|
||||||
cache_counter.inc_misses(name)
|
|
||||||
ret = yield orig(self, *keyargs)
|
ret = yield orig(self, *keyargs)
|
||||||
|
|
||||||
prefill(*keyargs + (ret,))
|
cache.prefill(*keyargs + (ret,))
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def invalidate(*keyargs):
|
wrapped.invalidate = cache.invalidate
|
||||||
if len(keyargs) != num_args:
|
wrapped.prefill = cache.prefill
|
||||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
|
||||||
|
|
||||||
cache.pop(keyargs, None)
|
|
||||||
|
|
||||||
wrapped.invalidate = invalidate
|
|
||||||
wrapped.prefill = prefill
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
|
|
@ -17,7 +17,39 @@
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.storage._base import cached
|
from synapse.storage._base import Cache, cached
|
||||||
|
|
||||||
|
|
||||||
|
class CacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.cache = Cache("test")
|
||||||
|
|
||||||
|
def test_empty(self):
|
||||||
|
failed = False
|
||||||
|
try:
|
||||||
|
self.cache.get("foo")
|
||||||
|
except KeyError:
|
||||||
|
failed = True
|
||||||
|
|
||||||
|
self.assertTrue(failed)
|
||||||
|
|
||||||
|
def test_hit(self):
|
||||||
|
self.cache.prefill("foo", 123)
|
||||||
|
|
||||||
|
self.assertEquals(self.cache.get("foo"), 123)
|
||||||
|
|
||||||
|
def test_invalidate(self):
|
||||||
|
self.cache.prefill("foo", 123)
|
||||||
|
self.cache.invalidate("foo")
|
||||||
|
|
||||||
|
failed = False
|
||||||
|
try:
|
||||||
|
self.cache.get("foo")
|
||||||
|
except KeyError:
|
||||||
|
failed = True
|
||||||
|
|
||||||
|
self.assertTrue(failed)
|
||||||
|
|
||||||
|
|
||||||
class CacheDecoratorTestCase(unittest.TestCase):
|
class CacheDecoratorTestCase(unittest.TestCase):
|
||||||
|
|
Loading…
Reference in a new issue