mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-21 03:21:54 +01:00
Fix some error cases in the caching layer. (#5749)
There was some inconsistent behaviour in the caching layer around how exceptions were handled - particularly synchronously-thrown ones. This seems to be most easily handled by pushing the creation of ObservableDeferreds down from CacheDescriptor to the Cache.
This commit is contained in:
parent
f16aa3a44b
commit
618bd1ee76
3 changed files with 130 additions and 35 deletions
1
changelog.d/5749.misc
Normal file
1
changelog.d/5749.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix some error cases in the caching layer.
|
|
@ -19,8 +19,7 @@ import logging
|
|||
import threading
|
||||
from collections import namedtuple
|
||||
|
||||
import six
|
||||
from six import itervalues, string_types
|
||||
from six import itervalues
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
|
@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred
|
|||
from synapse.util.caches import get_cache_factor_for
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||
from synapse.util.stringutils import to_ascii
|
||||
|
||||
from . import register_cache
|
||||
|
||||
|
@ -124,7 +122,7 @@ class Cache(object):
|
|||
update_metrics (bool): whether to update the cache hit rate metrics
|
||||
|
||||
Returns:
|
||||
Either a Deferred or the raw result
|
||||
Either an ObservableDeferred or the raw result
|
||||
"""
|
||||
callbacks = [callback] if callback else []
|
||||
val = self._pending_deferred_cache.get(key, _CacheSentinel)
|
||||
|
@ -148,9 +146,14 @@ class Cache(object):
|
|||
return default
|
||||
|
||||
def set(self, key, value, callback=None):
|
||||
if not isinstance(value, defer.Deferred):
|
||||
raise TypeError("not a Deferred")
|
||||
|
||||
callbacks = [callback] if callback else []
|
||||
self.check_thread()
|
||||
entry = CacheEntry(deferred=value, callbacks=callbacks)
|
||||
observable = ObservableDeferred(value, consumeErrors=True)
|
||||
observer = defer.maybeDeferred(observable.observe)
|
||||
entry = CacheEntry(deferred=observable, callbacks=callbacks)
|
||||
|
||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||
if existing_entry:
|
||||
|
@ -158,11 +161,16 @@ class Cache(object):
|
|||
|
||||
self._pending_deferred_cache[key] = entry
|
||||
|
||||
def shuffle(result):
|
||||
def compare_and_pop():
|
||||
"""Check if our entry is still the one in _pending_deferred_cache, and
|
||||
if so, pop it.
|
||||
|
||||
Returns true if the entries matched.
|
||||
"""
|
||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||
if existing_entry is entry:
|
||||
self.cache.set(key, result, entry.callbacks)
|
||||
else:
|
||||
return True
|
||||
|
||||
# oops, the _pending_deferred_cache has been updated since
|
||||
# we started our query, so we are out of date.
|
||||
#
|
||||
|
@ -172,6 +180,12 @@ class Cache(object):
|
|||
if existing_entry is not None:
|
||||
self._pending_deferred_cache[key] = existing_entry
|
||||
|
||||
return False
|
||||
|
||||
def cb(result):
|
||||
if compare_and_pop():
|
||||
self.cache.set(key, result, entry.callbacks)
|
||||
else:
|
||||
# we're not going to put this entry into the cache, so need
|
||||
# to make sure that the invalidation callbacks are called.
|
||||
# That was probably done when _pending_deferred_cache was
|
||||
|
@ -179,9 +193,16 @@ class Cache(object):
|
|||
# `invalidate` being previously called, in which case it may
|
||||
# not have been. Either way, let's double-check now.
|
||||
entry.invalidate()
|
||||
return result
|
||||
|
||||
entry.deferred.addCallback(shuffle)
|
||||
def eb(_fail):
|
||||
compare_and_pop()
|
||||
entry.invalidate()
|
||||
|
||||
# once the deferred completes, we can move the entry from the
|
||||
# _pending_deferred_cache to the real cache.
|
||||
#
|
||||
observer.addCallbacks(cb, eb)
|
||||
return observable
|
||||
|
||||
def prefill(self, key, value, callback=None):
|
||||
callbacks = [callback] if callback else []
|
||||
|
@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
|
|||
|
||||
ret.addErrback(onErr)
|
||||
|
||||
# If our cache_key is a string on py2, try to convert to ascii
|
||||
# to save a bit of space in large caches. Py3 does this
|
||||
# internally automatically.
|
||||
if six.PY2 and isinstance(cache_key, string_types):
|
||||
cache_key = to_ascii(cache_key)
|
||||
|
||||
result_d = ObservableDeferred(ret, consumeErrors=True)
|
||||
cache.set(cache_key, result_d, callback=invalidate_callback)
|
||||
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
|
||||
observer = result_d.observe()
|
||||
|
||||
if isinstance(observer, defer.Deferred):
|
||||
return make_deferred_yieldable(observer)
|
||||
else:
|
||||
return observer
|
||||
|
||||
if self.num_args == 1:
|
||||
wrapped.invalidate = lambda key: cache.invalidate(key[0])
|
||||
|
@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||
missing.add(arg)
|
||||
|
||||
if missing:
|
||||
# we need an observable deferred for each entry in the list,
|
||||
# we need a deferred for each entry in the list,
|
||||
# which we put in the cache. Each deferred resolves with the
|
||||
# relevant result for that key.
|
||||
deferreds_map = {}
|
||||
|
@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
|||
deferred = defer.Deferred()
|
||||
deferreds_map[arg] = deferred
|
||||
key = arg_to_cache_key(arg)
|
||||
observable = ObservableDeferred(deferred)
|
||||
cache.set(key, observable, callback=invalidate_callback)
|
||||
cache.set(key, deferred, callback=invalidate_callback)
|
||||
|
||||
def complete_all(res):
|
||||
# the wrapped function has completed. It returns a
|
||||
|
|
|
@ -27,6 +27,7 @@ from synapse.logging.context import (
|
|||
make_deferred_yieldable,
|
||||
)
|
||||
from synapse.util.caches import descriptors
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
|
|||
d2 = defer.Deferred()
|
||||
cache.set("key2", d2, partial(record_callback, 1))
|
||||
|
||||
# lookup should return the deferreds
|
||||
self.assertIs(cache.get("key1"), d1)
|
||||
self.assertIs(cache.get("key2"), d2)
|
||||
# lookup should return observable deferreds
|
||||
self.assertFalse(cache.get("key1").has_called())
|
||||
self.assertFalse(cache.get("key2").has_called())
|
||||
|
||||
# let one of the lookups complete
|
||||
d2.callback("result2")
|
||||
|
||||
# for now at least, the cache will return real results rather than an
|
||||
# observabledeferred
|
||||
self.assertEqual(cache.get("key2"), "result2")
|
||||
|
||||
# now do the invalidation
|
||||
|
@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
|
|||
self.assertEqual(r, "chips")
|
||||
obj.mock.assert_not_called()
|
||||
|
||||
def test_cache_with_sync_exception(self):
|
||||
"""If the wrapped function throws synchronously, things should continue to work
|
||||
"""
|
||||
|
||||
class Cls(object):
|
||||
@cached()
|
||||
def fn(self, arg1):
|
||||
raise SynapseError(100, "mai spoon iz too big!!1")
|
||||
|
||||
obj = Cls()
|
||||
|
||||
# this should fail immediately
|
||||
d = obj.fn(1)
|
||||
self.failureResultOf(d, SynapseError)
|
||||
|
||||
# ... leaving the cache empty
|
||||
self.assertEqual(len(obj.fn.cache.cache), 0)
|
||||
|
||||
# and a second call should result in a second exception
|
||||
d = obj.fn(1)
|
||||
self.failureResultOf(d, SynapseError)
|
||||
|
||||
def test_cache_logcontexts(self):
|
||||
"""Check that logcontexts are set and restored correctly when
|
||||
using the cache."""
|
||||
|
@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
|
|||
|
||||
self.assertEqual(LoggingContext.current_context(), c1)
|
||||
|
||||
# the cache should now be empty
|
||||
self.assertEqual(len(obj.fn.cache.cache), 0)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
# set off a deferred which will do a cache lookup
|
||||
|
@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
|
|||
self.assertEqual(r, "chips")
|
||||
obj.mock.assert_not_called()
|
||||
|
||||
def test_cache_iterable(self):
|
||||
class Cls(object):
|
||||
def __init__(self):
|
||||
self.mock = mock.Mock()
|
||||
|
||||
@descriptors.cached(iterable=True)
|
||||
def fn(self, arg1, arg2):
|
||||
return self.mock(arg1, arg2)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
obj.mock.return_value = ["spam", "eggs"]
|
||||
r = obj.fn(1, 2)
|
||||
self.assertEqual(r, ["spam", "eggs"])
|
||||
obj.mock.assert_called_once_with(1, 2)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# a call with different params should call the mock again
|
||||
obj.mock.return_value = ["chips"]
|
||||
r = obj.fn(1, 3)
|
||||
self.assertEqual(r, ["chips"])
|
||||
obj.mock.assert_called_once_with(1, 3)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# the two values should now be cached
|
||||
self.assertEqual(len(obj.fn.cache.cache), 3)
|
||||
|
||||
r = obj.fn(1, 2)
|
||||
self.assertEqual(r, ["spam", "eggs"])
|
||||
r = obj.fn(1, 3)
|
||||
self.assertEqual(r, ["chips"])
|
||||
obj.mock.assert_not_called()
|
||||
|
||||
def test_cache_iterable_with_sync_exception(self):
|
||||
"""If the wrapped function throws synchronously, things should continue to work
|
||||
"""
|
||||
|
||||
class Cls(object):
|
||||
@descriptors.cached(iterable=True)
|
||||
def fn(self, arg1):
|
||||
raise SynapseError(100, "mai spoon iz too big!!1")
|
||||
|
||||
obj = Cls()
|
||||
|
||||
# this should fail immediately
|
||||
d = obj.fn(1)
|
||||
self.failureResultOf(d, SynapseError)
|
||||
|
||||
# ... leaving the cache empty
|
||||
self.assertEqual(len(obj.fn.cache.cache), 0)
|
||||
|
||||
# and a second call should result in a second exception
|
||||
d = obj.fn(1)
|
||||
self.failureResultOf(d, SynapseError)
|
||||
|
||||
|
||||
class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
|
|
Loading…
Add table
Reference in a new issue