Change LRUCache to be tree-based so we can delete subtrees.

David Baker 2016-01-21 19:16:25 +00:00
parent 297eded261
commit f1f8122120
7 changed files with 140 additions and 52 deletions
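In short: cache keys become fixed-length tuples (keylen), LruCache keeps its entries in a TreeCache of nested dicts, and a new del_multi call drops every entry that shares a key prefix. An illustrative sketch of the new API, not part of the commit (the two-level key is invented for the example; the codebase is Python 2 at this point):

# Illustrative sketch of the API this commit introduces.
from synapse.util.caches.lrucache import LruCache

cache = LruCache(max_size=1000, keylen=2)   # keys must be 2-tuples
cache[("room1", "alice")] = "join"
cache[("room1", "bob")] = "join"
cache[("room2", "alice")] = "join"

cache.del_multi(("room1",))                 # removes the whole ("room1", ...) subtree
assert cache.get(("room1", "alice")) is None
assert cache.get(("room2", "alice")) == "join"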

View file

@@ -309,14 +309,14 @@ def _flatten_dict(d, prefix=[], result={}):
     return result
 
 
-regex_cache = LruCache(5000)
+regex_cache = LruCache(5000, 1)
 
 
 def _compile_regex(regex_str):
-    r = regex_cache.get(regex_str, None)
+    r = regex_cache.get((regex_str,), None)
     if r:
         return r
 
     r = re.compile(regex_str, flags=re.IGNORECASE)
-    regex_cache[regex_str] = r
+    regex_cache[(regex_str,)] = r
     return r

View file

@@ -38,7 +38,7 @@ class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True):
         if lru:
-            self.cache = LruCache(max_size=max_entries)
+            self.cache = LruCache(max_size=max_entries, keylen=keylen)
             self.max_entries = None
         else:
             self.cache = OrderedDict()
@@ -99,6 +99,15 @@ class Cache(object):
         self.sequence += 1
         self.cache.pop(key, None)
 
+    def invalidate_many(self, key):
+        self.check_thread()
+        if not isinstance(key, tuple):
+            raise TypeError(
+                "The cache key must be a tuple not %r" % (type(key),)
+            )
+        self.sequence += 1
+        self.cache.del_multi(key)
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
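The new invalidate_many mirrors invalidate but takes a key prefix and drops everything under it via LruCache.del_multi. A hypothetical usage sketch (the cache name and keys are invented; assumes the Cache was constructed with keylen=2 and lru=True):

# Hypothetical usage of Cache.invalidate_many; names and keys are made up.
cache = Cache("room_membership", keylen=2)
cache.prefill(("!room:hs", "@alice:hs"), "join")
cache.prefill(("!room:hs", "@bob:hs"), "join")

# The argument must be a tuple (a prefix of the full key), otherwise TypeError.
cache.invalidate_many(("!room:hs",))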

View file

@@ -32,7 +32,7 @@ class DictionaryCache(object):
     """
 
     def __init__(self, name, max_entries=1000):
-        self.cache = LruCache(max_size=max_entries)
+        self.cache = LruCache(max_size=max_entries, keylen=1)
 
         self.name = name
         self.sequence = 0
@@ -56,7 +56,7 @@ class DictionaryCache(object):
                 )
 
     def get(self, key, dict_keys=None):
-        entry = self.cache.get(key, self.sentinel)
+        entry = self.cache.get((key,), self.sentinel)
         if entry is not self.sentinel:
             cache_counter.inc_hits(self.name)
 
@@ -78,7 +78,7 @@ class DictionaryCache(object):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(key, None)
+        self.cache.pop((key,), None)
 
     def invalidate_all(self):
         self.check_thread()
@@ -96,8 +96,8 @@ class DictionaryCache(object):
             self._update_or_insert(key, value)
 
     def _update_or_insert(self, key, value):
-        entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+        entry = self.cache.setdefault((key,), DictionaryEntry(False, {}))
         entry.value.update(value)
 
     def _insert(self, key, value):
-        self.cache[key] = DictionaryEntry(True, value)
+        self.cache[(key,)] = DictionaryEntry(True, value)

View file

@@ -17,11 +17,23 @@
 from functools import wraps
 import threading
 
+from synapse.util.caches.treecache import TreeCache
+
+
+def enumerate_leaves(node, depth):
+    if depth == 0:
+        yield node
+    else:
+        for n in node.values():
+            for m in enumerate_leaves(n, depth - 1):
+                yield m
+
 
 class LruCache(object):
     """Least-recently-used cache."""
-    def __init__(self, max_size):
-        cache = {}
+    def __init__(self, max_size, keylen):
+        cache = TreeCache()
+        self.size = 0
         list_root = []
         list_root[:] = [list_root, list_root, None, None]
 
@@ -44,6 +56,7 @@ class LruCache(object):
             prev_node[NEXT] = node
             next_node[PREV] = node
             cache[key] = node
+            self.size += 1
 
         def move_node_to_front(node):
             prev_node = node[PREV]
@@ -62,7 +75,7 @@ class LruCache(object):
             next_node = node[NEXT]
             prev_node[NEXT] = next_node
             next_node[PREV] = prev_node
-            cache.pop(node[KEY], None)
+            self.size -= 1
 
         @synchronized
         def cache_get(key, default=None):
@@ -81,8 +94,10 @@ class LruCache(object):
                 node[VALUE] = value
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    delete_node(list_root[PREV])
+                if self.size > max_size:
+                    todelete = list_root[PREV]
+                    delete_node(todelete)
+                    cache.pop(todelete[KEY], None)
 
         @synchronized
         def cache_set_default(key, value):
@@ -91,8 +106,10 @@ class LruCache(object):
                 return node[VALUE]
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    delete_node(list_root[PREV])
+                if self.size > max_size:
+                    todelete = list_root[PREV]
+                    delete_node(todelete)
+                    cache.pop(todelete[KEY], None)
                 return value
 
         @synchronized
@@ -100,10 +117,19 @@ class LruCache(object):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
+                cache.pop(node[KEY], None)
                 return node[VALUE]
             else:
                 return default
 
+        @synchronized
+        def cache_del_multi(key):
+            popped = cache.pop(key)
+            if popped is None:
+                return
+            for leaf in enumerate_leaves(popped, keylen - len(key)):
+                delete_node(leaf)
+
         @synchronized
         def cache_clear():
             list_root[NEXT] = list_root
@@ -112,7 +138,7 @@ class LruCache(object):
 
         @synchronized
         def cache_len():
-            return len(cache)
+            return self.size
 
         @synchronized
         def cache_contains(key):
@@ -123,6 +149,7 @@ class LruCache(object):
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        self.del_multi = cache_del_multi
         self.len = cache_len
         self.contains = cache_contains
         self.clear = cache_clear
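Note that delete_node now only unlinks an entry from the LRU list and decrements the size counter; removing it from the backing TreeCache is done separately by the callers, since del_multi has already popped the whole subtree before unlinking its leaves. A standalone sketch of how enumerate_leaves walks such a popped subtree (example data invented):

# Standalone sketch: enumerate_leaves yields whatever sits `depth` levels below
# `node`; cache_del_multi uses it to find the LRU list nodes of a popped subtree.
def enumerate_leaves(node, depth):
    if depth == 0:
        yield node
    else:
        for n in node.values():
            for m in enumerate_leaves(n, depth - 1):
                yield m

subtree = {"alice": "node-a", "bob": "node-b"}  # shape of a popped TreeCache subtree
assert sorted(enumerate_leaves(subtree, 1)) == ["node-a", "node-b"]
assert list(enumerate_leaves("node-x", 0)) == ["node-x"]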

View file

@@ -0,0 +1,52 @@
+SENTINEL = object()
+
+
+class TreeCache(object):
+    def __init__(self):
+        self.root = {}
+
+    def __setitem__(self, key, value):
+        return self.set(key, value)
+
+    def set(self, key, value):
+        node = self.root
+        for k in key[:-1]:
+            node = node.setdefault(k, {})
+        node[key[-1]] = value
+
+    def get(self, key, default=None):
+        node = self.root
+        for k in key[:-1]:
+            node = node.get(k, None)
+            if node is None:
+                return default
+        return node.get(key[-1], default)
+
+    def clear(self):
+        self.root = {}
+
+    def pop(self, key, default=None):
+        nodes = []
+
+        node = self.root
+        for k in key[:-1]:
+            node = node.get(k, None)
+            nodes.append(node)  # don't add the root node
+            if node is None:
+                return default
+        popped = node.pop(key[-1], SENTINEL)
+        if popped is SENTINEL:
+            return default
+
+        node_and_keys = zip(nodes, key)
+        node_and_keys.reverse()
+        node_and_keys.append((self.root, None))
+
+        for i in range(len(node_and_keys) - 1):
+            n, k = node_and_keys[i]
+
+            if n:
+                break
+            node_and_keys[i + 1][0].pop(k)
+
+        return popped
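TreeCache stores a tuple key as a path through nested dicts; pop with a full key removes one leaf (pruning now-empty parents on the way back up), while pop with a shorter prefix, as del_multi does, removes and returns the whole sub-dict. A usage sketch (Python 2, since pop relies on zip() returning a list):

# Usage sketch for TreeCache: tuple keys map onto nested dicts.
cache = TreeCache()
cache[("a", "x")] = 1
cache[("a", "y")] = 2
cache[("b", "x")] = 3
assert cache.root == {"a": {"x": 1, "y": 2}, "b": {"x": 3}}
assert cache.get(("a", "x")) == 1

assert cache.pop(("a", "x")) == 1       # full key: removes a single leaf
assert cache.pop(("b",)) == {"x": 3}    # key prefix: removes and returns the subtree
assert cache.get(("b", "x")) is None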

View file

@@ -56,42 +56,42 @@ class CacheTestCase(unittest.TestCase):
     def test_eviction(self):
         cache = Cache("test", max_entries=2)
 
-        cache.prefill(1, "one")
-        cache.prefill(2, "two")
-        cache.prefill(3, "three")  # 1 will be evicted
+        cache.prefill((1,), "one")
+        cache.prefill((2,), "two")
+        cache.prefill((3,), "three")  # 1 will be evicted
 
         failed = False
         try:
-            cache.get(1)
+            cache.get((1,))
         except KeyError:
             failed = True
 
         self.assertTrue(failed)
 
-        cache.get(2)
-        cache.get(3)
+        cache.get((2,))
+        cache.get((3,))
 
     def test_eviction_lru(self):
         cache = Cache("test", max_entries=2, lru=True)
 
-        cache.prefill(1, "one")
-        cache.prefill(2, "two")
+        cache.prefill((1,), "one")
+        cache.prefill((2,), "two")
 
         # Now access 1 again, thus causing 2 to be least-recently used
-        cache.get(1)
+        cache.get((1,))
 
-        cache.prefill(3, "three")
+        cache.prefill((3,), "three")
 
         failed = False
         try:
-            cache.get(2)
+            cache.get((2,))
         except KeyError:
             failed = True
 
         self.assertTrue(failed)
 
-        cache.get(1)
-        cache.get(3)
+        cache.get((1,))
+        cache.get((3,))
 
 
 class CacheDecoratorTestCase(unittest.TestCase):

View file

@@ -21,34 +21,34 @@ from synapse.util.caches.lrucache import LruCache
 
 class LruCacheTestCase(unittest.TestCase):
 
     def test_get_set(self):
-        cache = LruCache(1)
-        cache["key"] = "value"
-        self.assertEquals(cache.get("key"), "value")
-        self.assertEquals(cache["key"], "value")
+        cache = LruCache(1, 1)
+        cache[("key",)] = "value"
+        self.assertEquals(cache.get(("key",)), "value")
+        self.assertEquals(cache[("key",)], "value")
 
     def test_eviction(self):
-        cache = LruCache(2)
-        cache[1] = 1
-        cache[2] = 2
-        self.assertEquals(cache.get(1), 1)
-        self.assertEquals(cache.get(2), 2)
-        cache[3] = 3
-        self.assertEquals(cache.get(1), None)
-        self.assertEquals(cache.get(2), 2)
-        self.assertEquals(cache.get(3), 3)
+        cache = LruCache(2, 1)
+        cache[(1,)] = 1
+        cache[(2,)] = 2
+        self.assertEquals(cache.get((1,)), 1)
+        self.assertEquals(cache.get((2,)), 2)
+        cache[(3,)] = 3
+        self.assertEquals(cache.get((1,)), None)
+        self.assertEquals(cache.get((2,)), 2)
+        self.assertEquals(cache.get((3,)), 3)
 
     def test_setdefault(self):
-        cache = LruCache(1)
-        self.assertEquals(cache.setdefault("key", 1), 1)
-        self.assertEquals(cache.get("key"), 1)
-        self.assertEquals(cache.setdefault("key", 2), 1)
-        self.assertEquals(cache.get("key"), 1)
+        cache = LruCache(1, 1)
+        self.assertEquals(cache.setdefault(("key",), 1), 1)
+        self.assertEquals(cache.get(("key",)), 1)
+        self.assertEquals(cache.setdefault(("key",), 2), 1)
+        self.assertEquals(cache.get(("key",)), 1)
 
     def test_pop(self):
-        cache = LruCache(1)
-        cache["key"] = 1
-        self.assertEquals(cache.pop("key"), 1)
-        self.assertEquals(cache.pop("key"), None)
+        cache = LruCache(1, 1)
+        cache[("key",)] = 1
+        self.assertEquals(cache.pop(("key",)), 1)
+        self.assertEquals(cache.pop(("key",)), None)