Make LruCache thread safe, as its used for event cache

This commit is contained in:
Erik Johnston 2015-04-15 16:07:59 +01:00
parent a5c72780e6
commit 806f380a8b

View file

@ -14,6 +14,10 @@
# limitations under the License. # limitations under the License.
from functools import wraps
import threading
class LruCache(object): class LruCache(object):
"""Least-recently-used cache.""" """Least-recently-used cache."""
# TODO(mjark) Add mutex for linked list for thread safety. # TODO(mjark) Add mutex for linked list for thread safety.
@ -24,6 +28,16 @@ class LruCache(object):
PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 PREV, NEXT, KEY, VALUE = 0, 1, 2, 3
lock = threading.Lock()
def synchronized(f):
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return inner
def add_node(key, value): def add_node(key, value):
prev_node = list_root prev_node = list_root
next_node = prev_node[NEXT] next_node = prev_node[NEXT]
@ -51,6 +65,7 @@ class LruCache(object):
next_node[PREV] = prev_node next_node[PREV] = prev_node
cache.pop(node[KEY], None) cache.pop(node[KEY], None)
@synchronized
def cache_get(key, default=None): def cache_get(key, default=None):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
@ -59,6 +74,7 @@ class LruCache(object):
else: else:
return default return default
@synchronized
def cache_set(key, value): def cache_set(key, value):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
@ -69,6 +85,7 @@ class LruCache(object):
if len(cache) > max_size: if len(cache) > max_size:
delete_node(list_root[PREV]) delete_node(list_root[PREV])
@synchronized
def cache_set_default(key, value): def cache_set_default(key, value):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
@ -79,6 +96,7 @@ class LruCache(object):
delete_node(list_root[PREV]) delete_node(list_root[PREV])
return value return value
@synchronized
def cache_pop(key, default=None): def cache_pop(key, default=None):
node = cache.get(key, None) node = cache.get(key, None)
if node: if node:
@ -87,9 +105,11 @@ class LruCache(object):
else: else:
return default return default
@synchronized
def cache_len(): def cache_len():
return len(cache) return len(cache)
@synchronized
def cache_contains(key): def cache_contains(key):
return key in cache return key in cache