0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-02 00:33:57 +01:00

Merge pull request #8548 from matrix-org/rav/deferred_cache

Rename Cache to DeferredCache, and related changes
This commit is contained in:
Richard van der Hoff 2020-10-15 11:42:07 +01:00 committed by GitHub
commit 0a08cd1065
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 454 additions and 374 deletions

1
changelog.d/8548.misc Normal file
View file

@ -0,0 +1 @@
Rename `Cache` to `DeferredCache`, to better reflect its purpose.

View file

@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache from synapse.util.caches.deferred_cache import DeferredCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache( self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000 name="client_ip_last_seen", keylen=4, max_entries=50000
) ) # type: DeferredCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())

View file

@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore): class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000 name="client_ip_last_seen", keylen=4, max_entries=50000
) )

View file

@ -34,7 +34,8 @@ from synapse.storage.database import (
) )
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@ -1004,7 +1005,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists. # the device exists.
self.device_id_exists_cache = Cache( self.device_id_exists_cache = DeferredCache(
name="device_id_exists", keylen=2, max_entries=10000 name="device_id_exists", keylen=2, max_entries=10000
) )

View file

@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -145,7 +146,7 @@ class EventsWorkerStore(SQLBaseStore):
self._cleanup_old_transaction_ids, self._cleanup_old_transaction_ids,
) )
self._get_event_cache = Cache( self._get_event_cache = DeferredCache(
"*getEvent*", "*getEvent*",
keylen=3, keylen=3,
max_entries=hs.config.caches.event_cache_size, max_entries=hs.config.caches.event_cache_size,

View file

@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import threading
from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
KT = TypeVar("KT")
VT = TypeVar("VT")
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred.
"""
__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(
self,
name: str,
max_entries: int = 1000,
keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key. Ignored unless
`tree` is True.
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
"""
cache_type = TreeCache if tree else dict
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = (
cache_type()
) # type: MutableMapping[KT, CacheEntry]
# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property
def max_entries(self):
return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(
self,
key: KT,
default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
):
"""Looks the key up in the caches.
Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the result itself
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
if val is not _Sentinel.sentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _Sentinel.sentinel:
raise KeyError()
else:
return default
def set(
self,
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
) -> ObservableDeferred:
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
self.cache.pop(key, None)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
entry.invalidate()
def invalidate_many(self, key: KT):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()

View file

@ -13,25 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools import functools
import inspect import inspect
import logging import logging
import threading
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from prometheus_client import Gauge
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from . import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,239 +48,6 @@ class _CachedFunction(Generic[F]):
__call__ = None # type: F __call__ = None # type: F
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
_CacheSentinel = object()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache:
__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(
self,
name: str,
max_entries: int = 1000,
keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
Returns:
Cache
"""
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property
def max_entries(self):
return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
"""Looks the key up in the caches.
Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
else:
return default
def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
self.cache.pop(key, None)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
entry.invalidate()
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
class _CacheDescriptorBase: class _CacheDescriptorBase:
def __init__(self, orig: _CachedFunction, num_args, cache_context=False): def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
self.orig = orig self.orig = orig
@ -390,13 +150,13 @@ class CacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable self.iterable = iterable
def __get__(self, obj, owner): def __get__(self, obj, owner):
cache = Cache( cache = DeferredCache(
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
tree=self.tree, tree=self.tree,
iterable=self.iterable, iterable=self.iterable,
) ) # type: DeferredCache[Tuple, Any]
def get_cache_key_gen(args, kwargs): def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into """Given some args/kwargs return a generator that resolves into
@ -640,9 +400,9 @@ class _CacheContext:
_cache_context_objects = ( _cache_context_objects = (
WeakValueDictionary() WeakValueDictionary()
) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
self._cache = cache self._cache = cache
self._cache_key = cache_key self._cache_key = cache_key
@ -651,7 +411,9 @@ class _CacheContext:
self._cache.invalidate(self._cache_key) self._cache.invalidate(self._cache_key)
@classmethod @classmethod
def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext def get_instance(
cls, cache, cache_key
): # type: (DeferredCache, CacheKey) -> _CacheContext
"""Returns an instance constructed with the given arguments. """Returns an instance constructed with the given arguments.
A new instance is only created if none already exists. A new instance is only created if none already exists.

View file

@ -64,7 +64,8 @@ class LruCache:
Args: Args:
max_size: The maximum amount of entries the cache can hold max_size: The maximum amount of entries the cache can hold
keylen: The length of the tuple used as the cache key keylen: The length of the tuple used as the cache key. Ignored unless
cache_type is `TreeCache`.
cache_type (type): cache_type (type):
type of underlying cache to be used. Typically one of dict type of underlying cache to be used. Typically one of dict

View file

@ -20,82 +20,11 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.descriptors import cached
from tests import unittest from tests import unittest
class CacheTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.cache = Cache("test")
def test_empty(self):
failed = False
try:
self.cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
self.cache.prefill("foo", 123)
self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self):
self.cache.prefill(("foo",), 123)
self.cache.invalidate(("foo",))
failed = False
try:
self.cache.get(("foo",))
except KeyError:
failed = True
self.assertTrue(failed)
def test_eviction(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)
class CacheDecoratorTestCase(unittest.HomeserverTestCase): class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_passthrough(self): def test_passthrough(self):

View file

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.util.caches.descriptors import Cache from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest from tests import unittest
@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
Caches produce metrics reflecting their state when scraped. Caches produce metrics reflecting their state when scraped.
""" """
CACHE_NAME = "cache_metrics_test_fgjkbdfg" CACHE_NAME = "cache_metrics_test_fgjkbdfg"
cache = Cache(CACHE_NAME, max_entries=777) cache = DeferredCache(CACHE_NAME, max_entries=777)
items = { items = {
x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")

View file

@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache
class DeferredCacheTestCase(unittest.TestCase):
def test_empty(self):
cache = DeferredCache("test")
failed = False
try:
cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
cache = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEquals(cache.get("foo"), 123)
def test_invalidate(self):
cache = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
failed = False
try:
cache.get(("foo",))
except KeyError:
failed = True
self.assertTrue(failed)
def test_invalidate_all(self):
cache = DeferredCache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
def test_eviction(self):
cache = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)

View file

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from functools import partial
import mock import mock
@ -42,49 +41,6 @@ def run_on_reactor():
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
class DescriptorTestCase(unittest.TestCase): class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache(self): def test_cache(self):