Implement a CacheListDescriptor

This commit is contained in:
Erik Johnston 2015-08-07 18:14:49 +01:00
parent ffdb8c3828
commit 0211890134

View file

@ -16,6 +16,7 @@ import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
@ -231,6 +232,101 @@ class CacheDescriptor(object):
return wrapped return wrapped
class CacheListDescriptor(object):
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
cached = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
self.function_to_call,
**args_to_call
)
ret_d = ObservableDeferred(ret_d)
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r[arg], arg)
observer = ObservableDeferred(observer)
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
return defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=True): def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
@ -250,6 +346,16 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
) )
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
)
class LoggingTransaction(object): class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()