0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-05-20 12:33:46 +02:00

Fix some error cases in the caching layer. (#5749)

There was some inconsistent behaviour in the caching layer around how
exceptions were handled - particularly synchronously-thrown ones.

This seems to be most easily handled by pushing the creation of
ObservableDeferreds down from CacheDescriptor to the Cache.
This commit is contained in:
Richard van der Hoff 2019-07-25 15:59:45 +01:00 committed by GitHub
parent f16aa3a44b
commit 618bd1ee76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 35 deletions

1
changelog.d/5749.misc Normal file
View file

@ -0,0 +1 @@
Fix some error cases in the caching layer.

View file

@ -19,8 +19,7 @@ import logging
import threading
from collections import namedtuple
import six
from six import itervalues, string_types
from six import itervalues
from prometheus_client import Gauge
@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache
@ -124,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either a Deferred or the raw result
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@ -148,9 +146,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
entry = CacheEntry(deferred=value, callbacks=callbacks)
observable = ObservableDeferred(value, consumeErrors=True)
observer = defer.maybeDeferred(observable.observe)
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@ -158,20 +161,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
def shuffle(result):
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@ -179,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
# If our cache_key is a string on py2, try to convert to ascii
# to save a bit of space in large caches. Py3 does this
# internally automatically.
if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
if isinstance(observer, defer.Deferred):
return make_deferred_yieldable(observer)
else:
return observer
return make_deferred_yieldable(observer)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
# we need an observable deferred for each entry in the list,
# we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)
cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a

View file

@ -27,6 +27,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached
from tests import unittest
@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return the deferreds
self.assertIs(cache.get("key1"), d1)
self.assertIs(cache.get("key2"), d2)
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""
class Cls(object):
@cached()
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(LoggingContext.current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
obj = Cls()
# set off a deferred which will do a cache lookup
@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
def test_cache_iterable(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached(iterable=True)
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
self.assertEqual(len(obj.fn.cache.cache), 3)
r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""
class Cls(object):
@descriptors.cached(iterable=True)
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks