From 261d809a4779b03c81ada52ed3893b2ad8782a96 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 5 May 2015 14:08:03 +0100 Subject: [PATCH] Sequence the modifications to the cache so that selects don't race with inserts --- synapse/storage/_base.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c328b5274..7f5477dee 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -31,6 +31,7 @@ import functools import simplejson as json import sys import time +import threading logger = logging.getLogger(__name__) @@ -68,9 +69,20 @@ class Cache(object): self.name = name self.keylen = keylen - + self.sequence = 0 + self.thread = None caches_by_name[name] = self.cache + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + def get(self, *keyargs): if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) @@ -82,6 +94,11 @@ class Cache(object): cache_counter.inc_misses(self.name) raise KeyError() + def update(self, sequence, *args): + self.check_thread() + if self.sequence == sequence: + self.prefill(*args) + def prefill(self, *args): # because I can't *keyargs, value keyargs = args[:-1] value = args[-1] @@ -96,9 +113,10 @@ class Cache(object): self.cache[keyargs] = value def invalidate(self, *keyargs): + self.check_thread() if len(keyargs) != self.keylen: raise ValueError("Expected a key to have %d items", self.keylen) - + self.sequence += 1 self.cache.pop(keyargs, None) @@ -130,9 +148,11 @@ def cached(max_entries=1000, num_args=1, lru=False): try: defer.returnValue(cache.get(*keyargs)) except KeyError: + sequence = cache.sequence + ret = yield orig(self, *keyargs) - cache.prefill(*keyargs + (ret,)) + cache.update(sequence, *keyargs + (ret,)) defer.returnValue(ret)