0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-16 00:03:49 +01:00

Merge pull request #1815 from matrix-org/erikj/iter_cache_size

Optionally measure size of cache by sum of length of values
This commit is contained in:
Erik Johnston 2017-01-17 11:51:09 +00:00 committed by GitHub
commit d11d7cdf87
12 changed files with 308 additions and 85 deletions

View file

@ -41,7 +41,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
@ -77,6 +77,9 @@ class _StateCacheEntry(object):
else: else:
self.state_id = _gen_state_id() self.state_id = _gen_state_id()
def __len__(self):
return len(self.state)
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
@ -99,6 +102,7 @@ class StateHandler(object):
clock=self.clock, clock=self.clock,
max_len=SIZE_OF_CACHE, max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True, reset_expiry_on_get=True,
) )

View file

@ -169,7 +169,7 @@ class SQLBaseStore(object):
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
) )
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()

View file

@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore):
room_id, state_group, state_ids, room_id, state_group, state_ids,
) )
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
max_entries=100000)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context, event=None): cache_context, event=None):
# We don't use `state_group`, it's there so that we can cache based # We don't use `state_group`, it's there so that we can cache based

View file

@ -284,7 +284,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results] return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f) return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, max_entries=1000) @cached(num_args=2, max_entries=100000, iterable=True)
def _get_state_group_from_group(self, group, types): def _get_state_group_from_group(self, group, types):
raise NotImplementedError() raise NotImplementedError()

View file

@ -17,7 +17,7 @@ import logging
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.logcontext import ( from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
) )
@ -42,6 +42,25 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
]
def __init__(self, deferred, sequence, callbacks):
self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache(object): class Cache(object):
__slots__ = ( __slots__ = (
"cache", "cache",
@ -51,12 +70,16 @@ class Cache(object):
"sequence", "sequence",
"thread", "thread",
"metrics", "metrics",
"_pending_deferred_cache",
) )
def __init__(self, name, max_entries=1000, keylen=1, tree=False): def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache( self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d.result)) if iterable else None,
) )
self.name = name self.name = name
@ -76,7 +99,15 @@ class Cache(object):
) )
def get(self, key, default=_CacheSentinel, callback=None): def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback) callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks)
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel: if val is not _CacheSentinel:
self.metrics.inc_hits() self.metrics.inc_hits()
return val return val
@ -88,15 +119,39 @@ class Cache(object):
else: else:
return default return default
def update(self, sequence, key, value, callback=None): def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread() self.check_thread()
if self.sequence == sequence: entry = CacheEntry(
# Only update the cache if the caches sequence number matches the deferred=value,
# number that the cache had before the SELECT was started (SYN-369) sequence=self.sequence,
self.prefill(key, value, callback=callback) callbacks=callbacks,
)
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def shuffle(result):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
self.cache.set(key, entry.deferred, entry.callbacks)
else:
entry.invalidate()
else:
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def prefill(self, key, value, callback=None): def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback) callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
@ -108,6 +163,10 @@ class Cache(object):
# Increment the sequence number so that any SELECT statements that # Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369) # raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1 self.sequence += 1
entry = self._pending_deferred_cache.pop(key, None)
if entry:
entry.invalidate()
self.cache.pop(key, None) self.cache.pop(key, None)
def invalidate_many(self, key): def invalidate_many(self, key):
@ -119,6 +178,12 @@ class Cache(object):
self.sequence += 1 self.sequence += 1
self.cache.del_multi(key) self.cache.del_multi(key)
val = self._pending_deferred_cache.pop(key, None)
if val is not None:
entry_dict, _ = val
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
self.sequence += 1 self.sequence += 1
@ -155,7 +220,7 @@ class CacheDescriptor(object):
""" """
def __init__(self, orig, max_entries=1000, num_args=1, tree=False, def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False, cache_context=False): inlineCallbacks=False, cache_context=False, iterable=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR) max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig self.orig = orig
@ -169,6 +234,8 @@ class CacheDescriptor(object):
self.num_args = num_args self.num_args = num_args
self.tree = tree self.tree = tree
self.iterable = iterable
all_args = inspect.getargspec(orig) all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1] self.arg_names = all_args.args[1:num_args + 1]
@ -203,6 +270,7 @@ class CacheDescriptor(object):
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
tree=self.tree, tree=self.tree,
iterable=self.iterable,
) )
@functools.wraps(self.orig) @functools.wraps(self.orig)
@ -243,11 +311,6 @@ class CacheDescriptor(object):
return preserve_context_over_deferred(observer) return preserve_context_over_deferred(observer)
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
preserve_context_over_fn, preserve_context_over_fn,
self.function_to_call, self.function_to_call,
@ -261,7 +324,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret, callback=invalidate_callback) cache.set(cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
@ -359,7 +422,6 @@ class CacheListDescriptor(object):
missing.append(arg) missing.append(arg)
if missing: if missing:
sequence = cache.sequence
args_to_call = dict(arg_dict) args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing args_to_call[self.list_name] = missing
@ -382,8 +444,8 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
cache.update( cache.set(
sequence, tuple(key), observer, tuple(key), observer,
callback=invalidate_callback callback=invalidate_callback
) )
@ -421,17 +483,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key) self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
tree=tree, tree=tree,
cache_context=cache_context, cache_context=cache_context,
iterable=iterable,
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
@ -439,6 +504,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
tree=tree, tree=tree,
inlineCallbacks=True, inlineCallbacks=True,
cache_context=cache_context, cache_context=cache_context,
iterable=iterable,
) )

View file

@ -23,7 +23,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
def __len__(self):
return len(self.value)
class DictionaryCache(object): class DictionaryCache(object):
@ -32,7 +34,7 @@ class DictionaryCache(object):
""" """
def __init__(self, name, max_entries=1000): def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries) self.cache = LruCache(max_size=max_entries, size_callback=len)
self.name = name self.name = name
self.sequence = 0 self.sequence = 0

View file

@ -15,6 +15,7 @@
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
from collections import OrderedDict
import logging import logging
@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class ExpiringCache(object): class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False): reset_expiry_on_get=False, iterable=False):
""" """
Args: Args:
cache_name (str): Name of this cache, used for logging. cache_name (str): Name of this cache, used for logging.
@ -36,6 +37,8 @@ class ExpiringCache(object):
evicted based on time. evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for reset_expiry_on_get (bool): If true, will reset the expiry time for
an item on access. Defaults to False. an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
""" """
self._cache_name = cache_name self._cache_name = cache_name
@ -47,9 +50,13 @@ class ExpiringCache(object):
self._reset_expiry_on_get = reset_expiry_on_get self._reset_expiry_on_get = reset_expiry_on_get
self._cache = {} self._cache = OrderedDict()
self.metrics = register_cache(cache_name, self._cache) self.metrics = register_cache(cache_name, self)
self.iterable = iterable
self._size_estimate = 0
def start(self): def start(self):
if not self._expiry_ms: if not self._expiry_ms:
@ -65,15 +72,14 @@ class ExpiringCache(object):
now = self._clock.time_msec() now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value) self._cache[key] = _CacheEntry(now, value)
# Evict if there are now too many items if self.iterable:
if self._max_len and len(self._cache.keys()) > self._max_len: self._size_estimate += len(value)
sorted_entries = sorted(
self._cache.items(),
key=lambda item: item[1].time,
)
for k, _ in sorted_entries[self._max_len:]: # Evict if there are now too many items
self._cache.pop(k) while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self._size_estimate -= len(value.value)
def __getitem__(self, key): def __getitem__(self, key):
try: try:
@ -99,7 +105,7 @@ class ExpiringCache(object):
# zero expiry time means don't expire. This should never get called # zero expiry time means don't expire. This should never get called
# since we have this check in start too. # since we have this check in start too.
return return
begin_length = len(self._cache) begin_length = len(self)
now = self._clock.time_msec() now = self._clock.time_msec()
@ -110,15 +116,20 @@ class ExpiringCache(object):
keys_to_delete.add(key) keys_to_delete.add(key)
for k in keys_to_delete: for k in keys_to_delete:
self._cache.pop(k) value = self._cache.pop(k)
if self.iterable:
self._size_estimate -= len(value.value)
logger.debug( logger.debug(
"[%s] _prune_cache before: %d, after len: %d", "[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self._cache) self._cache_name, begin_length, len(self)
) )
def __len__(self): def __len__(self):
return len(self._cache) if self.iterable:
return self._size_estimate
else:
return len(self._cache)
class _CacheEntry(object): class _CacheEntry(object):

View file

@ -49,7 +49,7 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted. when that key gets invalidated/evicted.
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
cache = cache_type() cache = cache_type()
self.cache = cache # Used for introspection. self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None) list_root = _Node(None, None, None, None)
@ -58,6 +58,12 @@ class LruCache(object):
lock = threading.Lock() lock = threading.Lock()
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
def synchronized(f): def synchronized(f):
@wraps(f) @wraps(f)
def inner(*args, **kwargs): def inner(*args, **kwargs):
@ -66,6 +72,16 @@ class LruCache(object):
return inner return inner
cached_cache_len = [0]
if size_callback is not None:
def cache_len():
return cached_cache_len[0]
else:
def cache_len():
return len(cache)
self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()): def add_node(key, value, callbacks=set()):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
@ -74,6 +90,9 @@ class LruCache(object):
next_node.prev_node = node next_node.prev_node = node
cache[key] = node cache[key] = node
if size_callback:
cached_cache_len[0] += size_callback(node.value)
def move_node_to_front(node): def move_node_to_front(node):
prev_node = node.prev_node prev_node = node.prev_node
next_node = node.next_node next_node = node.next_node
@ -92,23 +111,25 @@ class LruCache(object):
prev_node.next_node = next_node prev_node.next_node = next_node
next_node.prev_node = prev_node next_node.prev_node = prev_node
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
for cb in node.callbacks: for cb in node.callbacks:
cb() cb()
node.callbacks.clear() node.callbacks.clear()
@synchronized @synchronized
def cache_get(key, default=None, callback=None): def cache_get(key, default=None, callbacks=[]):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
if callback: node.callbacks.update(callbacks)
node.callbacks.add(callback)
return node.value return node.value
else: else:
return default return default
@synchronized @synchronized
def cache_set(key, value, callback=None): def cache_set(key, value, callbacks=[]):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
if value != node.value: if value != node.value:
@ -116,21 +137,18 @@ class LruCache(object):
cb() cb()
node.callbacks.clear() node.callbacks.clear()
if callback: if size_callback:
node.callbacks.add(callback) cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
move_node_to_front(node) move_node_to_front(node)
node.value = value node.value = value
else: else:
if callback: add_node(key, value, set(callbacks))
callbacks = set([callback])
else: evict()
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
@synchronized @synchronized
def cache_set_default(key, value): def cache_set_default(key, value):
@ -139,10 +157,7 @@ class LruCache(object):
return node.value return node.value
else: else:
add_node(key, value) add_node(key, value)
if len(cache) > max_size: evict()
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
return value return value
@synchronized @synchronized
@ -175,10 +190,6 @@ class LruCache(object):
cb() cb()
cache.clear() cache.clear()
@synchronized
def cache_len():
return len(cache)
@synchronized @synchronized
def cache_contains(key): def cache_contains(key):
return key in cache return key in cache
@ -190,7 +201,7 @@ class LruCache(object):
self.pop = cache_pop self.pop = cache_pop
if cache_type is TreeCache: if cache_type is TreeCache:
self.del_multi = cache_del_multi self.del_multi = cache_del_multi
self.len = cache_len self.len = synchronized(cache_len)
self.contains = cache_contains self.contains = cache_contains
self.clear = cache_clear self.clear = cache_clear

View file

@ -65,12 +65,27 @@ class TreeCache(object):
return popped return popped
def values(self): def values(self):
return [e.value for e in self.root.values()] return list(iterate_tree_cache_entry(self.root))
def __len__(self): def __len__(self):
return self.size return self.size
def iterate_tree_cache_entry(d):
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
if isinstance(d, dict):
for value_d in d.itervalues():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
if isinstance(d, _Entry):
yield d.value
else:
yield d
class _Entry(object): class _Entry(object):
__slots__ = ["value"] __slots__ = ["value"]

View file

@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0] callcount2 = [0]
class A(object): class A(object):
@cached(max_entries=2) @cached(max_entries=20) # HACK: This makes it 2 due to cache factor
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2) self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2) self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3") yield a.func("foo3")
self.assertEquals(callcount[0], 3) self.assertEquals(callcount[0], 3)

View file

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Copyright 2017 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import unittest
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
class ExpiringCacheTestCase(unittest.TestCase):
def test_get_set(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
cache["key"] = "value"
self.assertEquals(cache.get("key"), "value")
self.assertEquals(cache["key"], "value")
def test_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=2)
cache["key"] = "value"
cache["key2"] = "value2"
self.assertEquals(cache.get("key"), "value")
self.assertEquals(cache.get("key2"), "value2")
cache["key3"] = "value3"
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), "value2")
self.assertEquals(cache.get("key3"), "value3")
def test_iterable_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=5, iterable=True)
cache["key"] = [1]
cache["key2"] = [2, 3]
cache["key3"] = [4, 5]
self.assertEquals(cache.get("key"), [1])
self.assertEquals(cache.get("key2"), [2, 3])
self.assertEquals(cache.get("key3"), [4, 5])
cache["key4"] = [6, 7]
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), None)
self.assertEquals(cache.get("key3"), [4, 5])
self.assertEquals(cache.get("key4"), [6, 7])
def test_time_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, expiry_ms=1000)
cache.start()
cache["key"] = 1
clock.advance_time(0.5)
cache["key2"] = 2
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.get("key2"), 2)
clock.advance_time(0.9)
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), 2)
clock.advance_time(1)
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), None)

View file

@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value") cache.set("key", "value")
self.assertFalse(m.called) self.assertFalse(m.called)
cache.get("key", callback=m) cache.get("key", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
cache.get("key", "value") cache.get("key", "value")
@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value") cache.set("key", "value")
self.assertFalse(m.called) self.assertFalse(m.called)
cache.get("key", callback=m) cache.get("key", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
cache.get("key", callback=m) cache.get("key", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
cache.set("key", "value2") cache.set("key", "value2")
@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock() m = Mock()
cache = LruCache(1) cache = LruCache(1)
cache.set("key", "value", m) cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
cache.set("key", "value") cache.set("key", "value")
@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock() m = Mock()
cache = LruCache(1) cache = LruCache(1)
cache.set("key", "value", m) cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
cache.pop("key") cache.pop("key")
@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m4 = Mock() m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache) cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1) cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", m2) cache.set(("a", "2"), "value", callbacks=[m2])
cache.set(("b", "1"), "value", m3) cache.set(("b", "1"), "value", callbacks=[m3])
cache.set(("b", "2"), "value", m4) cache.set(("b", "2"), "value", callbacks=[m4])
self.assertEquals(m1.call_count, 0) self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m2 = Mock() m2 = Mock()
cache = LruCache(5) cache = LruCache(5)
cache.set("key1", "value", m1) cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", m2) cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0) self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m3 = Mock(name="m3") m3 = Mock(name="m3")
cache = LruCache(2) cache = LruCache(2)
cache.set("key1", "value", m1) cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", m2) cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0) self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0) self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3) cache.set("key3", "value", callbacks=[m3])
self.assertEquals(m1.call_count, 1) self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0) self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1) cache.set("key1", "value", callbacks=[m1])
self.assertEquals(m1.call_count, 1) self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1) self.assertEquals(m3.call_count, 1)
class LruCacheSizedTestCase(unittest.TestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
cache["key4"] = [4]
self.assertEquals(cache["key1"], [0])
self.assertEquals(cache["key2"], [1, 2])
self.assertEquals(cache["key3"], [3])
self.assertEquals(cache["key4"], [4])
self.assertEquals(len(cache), 5)
cache["key5"] = [5, 6]
self.assertEquals(len(cache), 4)
self.assertEquals(cache.get("key1"), None)
self.assertEquals(cache.get("key2"), None)
self.assertEquals(cache["key3"], [3])
self.assertEquals(cache["key4"], [4])
self.assertEquals(cache["key5"], [5, 6])