mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-05 00:14:35 +01:00
Sequence the modifications to the cache so that selects don't race with inserts
This commit is contained in:
parent
d9cc5de9e5
commit
261d809a47
1 changed files with 23 additions and 3 deletions
|
@ -31,6 +31,7 @@ import functools
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -68,9 +69,20 @@ class Cache(object):
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.keylen = keylen
|
self.keylen = keylen
|
||||||
|
self.sequence = 0
|
||||||
|
self.thread = None
|
||||||
caches_by_name[name] = self.cache
|
caches_by_name[name] = self.cache
|
||||||
|
|
||||||
|
def check_thread(self):
|
||||||
|
expected_thread = self.thread
|
||||||
|
if expected_thread is None:
|
||||||
|
self.thread = threading.current_thread()
|
||||||
|
else:
|
||||||
|
if expected_thread is not threading.current_thread():
|
||||||
|
raise ValueError(
|
||||||
|
"Cache objects can only be accessed from the main thread"
|
||||||
|
)
|
||||||
|
|
||||||
def get(self, *keyargs):
|
def get(self, *keyargs):
|
||||||
if len(keyargs) != self.keylen:
|
if len(keyargs) != self.keylen:
|
||||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
@ -82,6 +94,11 @@ class Cache(object):
|
||||||
cache_counter.inc_misses(self.name)
|
cache_counter.inc_misses(self.name)
|
||||||
raise KeyError()
|
raise KeyError()
|
||||||
|
|
||||||
|
def update(self, sequence, *args):
|
||||||
|
self.check_thread()
|
||||||
|
if self.sequence == sequence:
|
||||||
|
self.prefill(*args)
|
||||||
|
|
||||||
def prefill(self, *args): # because I can't *keyargs, value
|
def prefill(self, *args): # because I can't *keyargs, value
|
||||||
keyargs = args[:-1]
|
keyargs = args[:-1]
|
||||||
value = args[-1]
|
value = args[-1]
|
||||||
|
@ -96,9 +113,10 @@ class Cache(object):
|
||||||
self.cache[keyargs] = value
|
self.cache[keyargs] = value
|
||||||
|
|
||||||
def invalidate(self, *keyargs):
|
def invalidate(self, *keyargs):
|
||||||
|
self.check_thread()
|
||||||
if len(keyargs) != self.keylen:
|
if len(keyargs) != self.keylen:
|
||||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||||
|
self.sequence += 1
|
||||||
self.cache.pop(keyargs, None)
|
self.cache.pop(keyargs, None)
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,9 +148,11 @@ def cached(max_entries=1000, num_args=1, lru=False):
|
||||||
try:
|
try:
|
||||||
defer.returnValue(cache.get(*keyargs))
|
defer.returnValue(cache.get(*keyargs))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
sequence = cache.sequence
|
||||||
|
|
||||||
ret = yield orig(self, *keyargs)
|
ret = yield orig(self, *keyargs)
|
||||||
|
|
||||||
cache.prefill(*keyargs + (ret,))
|
cache.update(sequence, *keyargs + (ret,))
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue