mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-12 13:01:34 +01:00
Implement and use an @lru_cache decorator (#8595)
We don't always need the full power of a DeferredCache.
This commit is contained in:
parent
fd7c743445
commit
cbc82aa09f
4 changed files with 272 additions and 61 deletions
1
changelog.d/8595.misc
Normal file
1
changelog.d/8595.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Implement and use an @lru_cache decorator.
|
|
@ -15,8 +15,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
||||||
|
@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
|
||||||
from synapse.state import POWER_KEY
|
from synapse.state import POWER_KEY
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import register_cache
|
from synapse.util.caches import register_cache
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import lru_cache
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
|
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
|
||||||
dict of user_id -> push_rules
|
dict of user_id -> push_rules
|
||||||
"""
|
"""
|
||||||
room_id = event.room_id
|
room_id = event.room_id
|
||||||
rules_for_room = await self._get_rules_for_room(room_id)
|
rules_for_room = self._get_rules_for_room(room_id)
|
||||||
|
|
||||||
rules_by_user = await rules_for_room.get_rules(event, context)
|
rules_by_user = await rules_for_room.get_rules(event, context)
|
||||||
|
|
||||||
|
@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
|
||||||
|
|
||||||
return rules_by_user
|
return rules_by_user
|
||||||
|
|
||||||
@cached()
|
@lru_cache()
|
||||||
def _get_rules_for_room(self, room_id):
|
def _get_rules_for_room(self, room_id):
|
||||||
"""Get the current RulesForRoom object for the given room id
|
"""Get the current RulesForRoom object for the given room id
|
||||||
|
|
||||||
|
@ -275,12 +276,14 @@ class RulesForRoom:
|
||||||
the entire cache for the room.
|
the entire cache for the room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
|
def __init__(
|
||||||
|
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hs (HomeServer)
|
hs (HomeServer)
|
||||||
room_id (str)
|
room_id (str)
|
||||||
rules_for_room_cache(Cache): The cache object that caches these
|
rules_for_room_cache: The cache object that caches these
|
||||||
RoomsForUser objects.
|
RoomsForUser objects.
|
||||||
room_push_rule_cache_metrics (CacheMetric)
|
room_push_rule_cache_metrics (CacheMetric)
|
||||||
"""
|
"""
|
||||||
|
@ -489,13 +492,21 @@ class RulesForRoom:
|
||||||
self.state_group = state_group
|
self.state_group = state_group
|
||||||
|
|
||||||
|
|
||||||
class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
|
@attr.attrs(slots=True, frozen=True)
|
||||||
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
|
class _Invalidation:
|
||||||
# which namedtuple does for us (i.e. two _CacheContext are the same if
|
# _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
|
||||||
# their caches and keys match). This is important in particular to
|
# which means that it it is stored on the bulk_get_push_rules cache entry. In order
|
||||||
# dedupe when we add callbacks to lru cache nodes, otherwise the number
|
# to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
|
||||||
# of callbacks would grow.
|
# we need to ensure that two _Invalidation objects are "equal" if they refer to the
|
||||||
|
# same `cache` and `room_id`.
|
||||||
|
#
|
||||||
|
# attrs provides suitable __hash__ and __eq__ methods, provided we remember to
|
||||||
|
# set `frozen=True`.
|
||||||
|
|
||||||
|
cache = attr.ib(type=LruCache)
|
||||||
|
room_id = attr.ib(type=str)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
|
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||||
if rules:
|
if rules:
|
||||||
rules.invalidate_all()
|
rules.invalidate_all()
|
||||||
|
|
|
@ -13,10 +13,23 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -24,6 +37,7 @@ from twisted.internet import defer
|
||||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
from synapse.util.caches.deferred_cache import DeferredCache
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
|
||||||
|
|
||||||
|
|
||||||
class _CacheDescriptorBase:
|
class _CacheDescriptorBase:
|
||||||
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
|
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
|
||||||
self.orig = orig
|
self.orig = orig
|
||||||
|
|
||||||
arg_spec = inspect.getfullargspec(orig)
|
arg_spec = inspect.getfullargspec(orig)
|
||||||
|
@ -97,8 +111,107 @@ class _CacheDescriptorBase:
|
||||||
|
|
||||||
self.add_cache_context = cache_context
|
self.add_cache_context = cache_context
|
||||||
|
|
||||||
|
self.cache_key_builder = get_cache_key_builder(
|
||||||
|
self.arg_names, self.arg_defaults
|
||||||
|
)
|
||||||
|
|
||||||
class CacheDescriptor(_CacheDescriptorBase):
|
|
||||||
|
class _LruCachedFunction(Generic[F]):
|
||||||
|
cache = None # type: LruCache[CacheKey, Any]
|
||||||
|
__call__ = None # type: F
|
||||||
|
|
||||||
|
|
||||||
|
def lru_cache(
|
||||||
|
max_entries: int = 1000, cache_context: bool = False,
|
||||||
|
) -> Callable[[F], _LruCachedFunction[F]]:
|
||||||
|
"""A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
|
This is more-or-less a drop-in equivalent to functools.lru_cache, although note
|
||||||
|
that the signature is slightly different.
|
||||||
|
|
||||||
|
The main differences with functools.lru_cache are:
|
||||||
|
(a) the size of the cache can be controlled via the cache_factor mechanism
|
||||||
|
(b) the wrapped function can request a "cache_context" which provides a
|
||||||
|
callback mechanism to indicate that the result is no longer valid
|
||||||
|
(c) prometheus metrics are exposed automatically.
|
||||||
|
|
||||||
|
The function should take zero or more arguments, which are used as the key for the
|
||||||
|
cache. Single-argument functions use that argument as the cache key; otherwise the
|
||||||
|
arguments are built into a tuple.
|
||||||
|
|
||||||
|
Cached functions can be "chained" (i.e. a cached function can call other cached
|
||||||
|
functions and get appropriately invalidated when they called caches are
|
||||||
|
invalidated) by adding a special "cache_context" argument to the function
|
||||||
|
and passing that as a kwarg to all caches called. For example:
|
||||||
|
|
||||||
|
@lru_cache(cache_context=True)
|
||||||
|
def foo(self, key, cache_context):
|
||||||
|
r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
|
||||||
|
r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||||
|
return r1 + r2
|
||||||
|
|
||||||
|
The wrapped function also has a 'cache' property which offers direct access to the
|
||||||
|
underlying LruCache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def func(orig: F) -> _LruCachedFunction[F]:
|
||||||
|
desc = LruCacheDescriptor(
|
||||||
|
orig, max_entries=max_entries, cache_context=cache_context,
|
||||||
|
)
|
||||||
|
return cast(_LruCachedFunction[F], desc)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class LruCacheDescriptor(_CacheDescriptorBase):
|
||||||
|
"""Helper for @lru_cache"""
|
||||||
|
|
||||||
|
class _Sentinel(enum.Enum):
|
||||||
|
sentinel = object()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, orig, max_entries: int = 1000, cache_context: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(orig, num_args=None, cache_context=cache_context)
|
||||||
|
self.max_entries = max_entries
|
||||||
|
|
||||||
|
def __get__(self, obj, owner):
|
||||||
|
cache = LruCache(
|
||||||
|
cache_name=self.orig.__name__, max_size=self.max_entries,
|
||||||
|
) # type: LruCache[CacheKey, Any]
|
||||||
|
|
||||||
|
get_cache_key = self.cache_key_builder
|
||||||
|
sentinel = LruCacheDescriptor._Sentinel.sentinel
|
||||||
|
|
||||||
|
@functools.wraps(self.orig)
|
||||||
|
def _wrapped(*args, **kwargs):
|
||||||
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
callbacks = (invalidate_callback,) if invalidate_callback else ()
|
||||||
|
|
||||||
|
cache_key = get_cache_key(args, kwargs)
|
||||||
|
|
||||||
|
ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
|
||||||
|
if ret != sentinel:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
# Add our own `cache_context` to argument list if the wrapped function
|
||||||
|
# has asked for one
|
||||||
|
if self.add_cache_context:
|
||||||
|
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
|
||||||
|
|
||||||
|
ret2 = self.orig(obj, *args, **kwargs)
|
||||||
|
cache.set(cache_key, ret2, callbacks=callbacks)
|
||||||
|
|
||||||
|
return ret2
|
||||||
|
|
||||||
|
wrapped = cast(_CachedFunction, _wrapped)
|
||||||
|
wrapped.cache = cache
|
||||||
|
obj.__dict__[self.orig.__name__] = wrapped
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
This caches deferreds, rather than the results themselves. Deferreds that
|
This caches deferreds, rather than the results themselves. Deferreds that
|
||||||
|
@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
cache_context=False,
|
cache_context=False,
|
||||||
iterable=False,
|
iterable=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
||||||
|
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
|
@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
iterable=self.iterable,
|
iterable=self.iterable,
|
||||||
) # type: DeferredCache[CacheKey, Any]
|
) # type: DeferredCache[CacheKey, Any]
|
||||||
|
|
||||||
def get_cache_key_gen(args, kwargs):
|
get_cache_key = self.cache_key_builder
|
||||||
"""Given some args/kwargs return a generator that resolves into
|
|
||||||
the cache_key.
|
|
||||||
|
|
||||||
We loop through each arg name, looking up if its in the `kwargs`,
|
|
||||||
otherwise using the next argument in `args`. If there are no more
|
|
||||||
args then we try looking the arg name up in the defaults
|
|
||||||
"""
|
|
||||||
pos = 0
|
|
||||||
for nm in self.arg_names:
|
|
||||||
if nm in kwargs:
|
|
||||||
yield kwargs[nm]
|
|
||||||
elif pos < len(args):
|
|
||||||
yield args[pos]
|
|
||||||
pos += 1
|
|
||||||
else:
|
|
||||||
yield self.arg_defaults[nm]
|
|
||||||
|
|
||||||
# By default our cache key is a tuple, but if there is only one item
|
|
||||||
# then don't bother wrapping in a tuple. This is to save memory.
|
|
||||||
if self.num_args == 1:
|
|
||||||
nm = self.arg_names[0]
|
|
||||||
|
|
||||||
def get_cache_key(args, kwargs):
|
|
||||||
if nm in kwargs:
|
|
||||||
return kwargs[nm]
|
|
||||||
elif len(args):
|
|
||||||
return args[0]
|
|
||||||
else:
|
|
||||||
return self.arg_defaults[nm]
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def get_cache_key(args, kwargs):
|
|
||||||
return tuple(get_cache_key_gen(args, kwargs))
|
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def _wrapped(*args, **kwargs):
|
def _wrapped(*args, **kwargs):
|
||||||
|
@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
|
||||||
else:
|
else:
|
||||||
wrapped.invalidate = cache.invalidate
|
wrapped.invalidate = cache.invalidate
|
||||||
wrapped.invalidate_all = cache.invalidate_all
|
|
||||||
wrapped.invalidate_many = cache.invalidate_many
|
wrapped.invalidate_many = cache.invalidate_many
|
||||||
wrapped.prefill = cache.prefill
|
wrapped.prefill = cache.prefill
|
||||||
|
|
||||||
|
@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class CacheListDescriptor(_CacheDescriptorBase):
|
class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
"""Wraps an existing cache to support bulk fetching of keys.
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
Given a list of keys it looks in the cache to find any hits, then passes
|
Given a list of keys it looks in the cache to find any hits, then passes
|
||||||
|
@ -382,11 +459,13 @@ class _CacheContext:
|
||||||
on a lower level.
|
on a lower level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Cache = Union[DeferredCache, LruCache]
|
||||||
|
|
||||||
_cache_context_objects = (
|
_cache_context_objects = (
|
||||||
WeakValueDictionary()
|
WeakValueDictionary()
|
||||||
) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
|
) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
|
||||||
|
|
||||||
def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
|
def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
self._cache_key = cache_key
|
self._cache_key = cache_key
|
||||||
|
|
||||||
|
@ -396,8 +475,8 @@ class _CacheContext:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(
|
def get_instance(
|
||||||
cls, cache, cache_key
|
cls, cache: "_CacheContext.Cache", cache_key: CacheKey
|
||||||
): # type: (DeferredCache, CacheKey) -> _CacheContext
|
) -> "_CacheContext":
|
||||||
"""Returns an instance constructed with the given arguments.
|
"""Returns an instance constructed with the given arguments.
|
||||||
|
|
||||||
A new instance is only created if none already exists.
|
A new instance is only created if none already exists.
|
||||||
|
@ -418,7 +497,7 @@ def cached(
|
||||||
cache_context: bool = False,
|
cache_context: bool = False,
|
||||||
iterable: bool = False,
|
iterable: bool = False,
|
||||||
) -> Callable[[F], _CachedFunction[F]]:
|
) -> Callable[[F], _CachedFunction[F]]:
|
||||||
func = lambda orig: CacheDescriptor(
|
func = lambda orig: DeferredCacheDescriptor(
|
||||||
orig,
|
orig,
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
num_args=num_args,
|
num_args=num_args,
|
||||||
|
@ -460,7 +539,7 @@ def cachedList(
|
||||||
def batch_do_something(self, first_arg, second_args):
|
def batch_do_something(self, first_arg, second_args):
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
func = lambda orig: CacheListDescriptor(
|
func = lambda orig: DeferredCacheListDescriptor(
|
||||||
orig,
|
orig,
|
||||||
cached_method_name=cached_method_name,
|
cached_method_name=cached_method_name,
|
||||||
list_name=list_name,
|
list_name=list_name,
|
||||||
|
@ -468,3 +547,65 @@ def cachedList(
|
||||||
)
|
)
|
||||||
|
|
||||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_key_builder(
|
||||||
|
param_names: Sequence[str], param_defaults: Mapping[str, Any]
|
||||||
|
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
|
||||||
|
"""Construct a function which will build cache keys suitable for a cached function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_names: list of formal parameter names for the cached function
|
||||||
|
param_defaults: a mapping from parameter name to default value for that param
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function which will take an (args, kwargs) pair and return a cache key
|
||||||
|
"""
|
||||||
|
|
||||||
|
# By default our cache key is a tuple, but if there is only one item
|
||||||
|
# then don't bother wrapping in a tuple. This is to save memory.
|
||||||
|
|
||||||
|
if len(param_names) == 1:
|
||||||
|
nm = param_names[0]
|
||||||
|
|
||||||
|
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||||
|
if nm in kwargs:
|
||||||
|
return kwargs[nm]
|
||||||
|
elif len(args):
|
||||||
|
return args[0]
|
||||||
|
else:
|
||||||
|
return param_defaults[nm]
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
|
||||||
|
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
|
||||||
|
|
||||||
|
return get_cache_key
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_key_gen(
|
||||||
|
param_names: Iterable[str],
|
||||||
|
param_defaults: Mapping[str, Any],
|
||||||
|
args: Sequence[Any],
|
||||||
|
kwargs: Mapping[str, Any],
|
||||||
|
) -> Iterable[Any]:
|
||||||
|
"""Given some args/kwargs return a generator that resolves into
|
||||||
|
the cache_key.
|
||||||
|
|
||||||
|
This is essentially the same operation as `inspect.getcallargs`, but optimised so
|
||||||
|
that we don't need to inspect the target function for each call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We loop through each arg name, looking up if its in the `kwargs`,
|
||||||
|
# otherwise using the next argument in `args`. If there are no more
|
||||||
|
# args then we try looking the arg name up in the defaults.
|
||||||
|
pos = 0
|
||||||
|
for nm in param_names:
|
||||||
|
if nm in kwargs:
|
||||||
|
yield kwargs[nm]
|
||||||
|
elif pos < len(args):
|
||||||
|
yield args[pos]
|
||||||
|
pos += 1
|
||||||
|
else:
|
||||||
|
yield param_defaults[nm]
|
||||||
|
|
|
@ -29,13 +29,46 @@ from synapse.logging.context import (
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
)
|
)
|
||||||
from synapse.util.caches import descriptors
|
from synapse.util.caches import descriptors
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached, lru_cache
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import get_awaitable_result
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LruCacheDecoratorTestCase(unittest.TestCase):
|
||||||
|
def test_base(self):
|
||||||
|
class Cls:
|
||||||
|
def __init__(self):
|
||||||
|
self.mock = mock.Mock()
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def fn(self, arg1, arg2):
|
||||||
|
return self.mock(arg1, arg2)
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
obj.mock.return_value = "fish"
|
||||||
|
r = obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
obj.mock.assert_called_once_with(1, 2)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# a call with different params should call the mock again
|
||||||
|
obj.mock.return_value = "chips"
|
||||||
|
r = obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, "chips")
|
||||||
|
obj.mock.assert_called_once_with(1, 3)
|
||||||
|
obj.mock.reset_mock()
|
||||||
|
|
||||||
|
# the two values should now be cached
|
||||||
|
r = obj.fn(1, 2)
|
||||||
|
self.assertEqual(r, "fish")
|
||||||
|
r = obj.fn(1, 3)
|
||||||
|
self.assertEqual(r, "chips")
|
||||||
|
obj.mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def run_on_reactor():
|
def run_on_reactor():
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
reactor.callLater(0, d.callback, 0)
|
reactor.callLater(0, d.callback, 0)
|
||||||
|
@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
d = obj.fn(1)
|
d = obj.fn(1)
|
||||||
self.failureResultOf(d, SynapseError)
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
|
def test_invalidate_cascade(self):
|
||||||
|
"""Invalidations should cascade up through cache contexts"""
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
@cached(cache_context=True)
|
||||||
|
async def func1(self, key, cache_context):
|
||||||
|
return await self.func2(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
async def func2(self, key, cache_context):
|
||||||
|
return self.func3(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
@lru_cache(cache_context=True)
|
||||||
|
def func3(self, key, cache_context):
|
||||||
|
self.invalidate = cache_context.invalidate
|
||||||
|
return 42
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
|
||||||
|
top_invalidate = mock.Mock()
|
||||||
|
r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
|
||||||
|
self.assertEqual(r, 42)
|
||||||
|
obj.invalidate()
|
||||||
|
top_invalidate.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
|
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
|
||||||
"""More tests for @cached
|
"""More tests for @cached
|
||||||
|
|
Loading…
Reference in a new issue