diff --git a/docs/log_contexts.rst b/docs/log_contexts.rst index 8d04a973de..eb1784e700 100644 --- a/docs/log_contexts.rst +++ b/docs/log_contexts.rst @@ -204,9 +204,14 @@ That doesn't follow the rules, but we can fix it by wrapping it with This technique works equally for external functions which return deferreds, or deferreds we have made ourselves. -XXX: think this is what ``preserve_context_over_deferred`` is supposed to do, -though it is broken, in that it only restores the logcontext for the duration -of the callbacks, which doesn't comply with the logcontext rules. +You can also use ``logcontext.make_deferred_yieldable``, which just does the +boilerplate for you, so the above could be written: + +.. code:: python + + def sleep(seconds): + return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds)) + Fire-and-forget --------------- diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 19595df422..5c30ed235d 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -15,12 +15,9 @@ import logging from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError +from synapse.util import unwrapFirstError, logcontext from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn -) from . import DEBUG_CACHES, register_cache @@ -328,11 +325,9 @@ class CacheDescriptor(_CacheDescriptorBase): defer.returnValue(cached_result) observer.addCallback(check_result) - return preserve_context_over_deferred(observer) except KeyError: ret = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) @@ -342,10 +337,11 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=True) - cache.set(cache_key, ret, callback=invalidate_callback) + result_d = ObservableDeferred(ret, consumeErrors=True) + cache.set(cache_key, result_d, callback=invalidate_callback) + observer = result_d.observe() - return preserve_context_over_deferred(ret.observe()) + return logcontext.make_deferred_yieldable(observer) wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all @@ -362,7 +358,11 @@ class CacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped fucntion. + the list of missing keys to the wrapped function. + + Once wrapped, the function returns either a Deferred which resolves to + the list of results, or (if all results were cached), just the list of + results. """ def __init__(self, orig, cached_method_name, list_name, num_args=None, @@ -433,8 +433,7 @@ class CacheListDescriptor(_CacheDescriptorBase): args_to_call[self.list_name] = missing ret_d = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), **args_to_call ) @@ -443,8 +442,7 @@ class CacheListDescriptor(_CacheDescriptorBase): # We need to create deferreds for each arg in the list so that # we can insert the new deferred into the cache. for arg in missing: - with PreserveLoggingContext(): - observer = ret_d.observe() + observer = ret_d.observe() observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) @@ -471,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase): results.update(res) return results - return preserve_context_over_deferred(defer.gatherResults( + return logcontext.make_deferred_yieldable(defer.gatherResults( cached_defers.values(), consumeErrors=True, ).addCallback(update_results_dict).addErrback( diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index ff67b1d794..857afee7cb 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs): def preserve_context_over_deferred(deferred, context=None): """Given a deferred wrap it such that any callbacks added later to it will be invoked with the current context. + + Deprecated: this almost certainly doesn't do want you want, ie make + the deferred follow the synapse logcontext rules: try + ``make_deferred_yieldable`` instead. """ if context is None: context = LoggingContext.current_context() @@ -359,6 +363,25 @@ def preserve_fn(f): return g +@defer.inlineCallbacks +def make_deferred_yieldable(deferred): + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed (or is not actually a Deferred), essentially + does nothing (just returns another completed deferred with the + result/failure). + + If the deferred has not yet completed, resets the logcontext before + returning a deferred. Then, when the deferred completes, restores the + current logcontext before running callbacks/errbacks. + + (This is more-or-less the opposite operation to preserve_fn.) + """ + with PreserveLoggingContext(): + r = yield deferred + defer.returnValue(r) + + # modules to ignore in `logcontext_tracer` _to_ignore = [ "synapse.util.logcontext", diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 419281054d..4414e86771 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -12,11 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + import mock +from synapse.api.errors import SynapseError +from synapse.util import async +from synapse.util import logcontext from twisted.internet import defer from synapse.util.caches import descriptors from tests import unittest +logger = logging.getLogger(__name__) + class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks @@ -84,3 +91,87 @@ class DescriptorTestCase(unittest.TestCase): r = yield obj.fn(2, 5) self.assertEqual(r, 'chips') obj.mock.assert_not_called() + + def test_cache_logcontexts(self): + """Check that logcontexts are set and restored correctly when + using the cache.""" + + complete_lookup = defer.Deferred() + + class Cls(object): + @descriptors.cached() + def fn(self, arg1): + @defer.inlineCallbacks + def inner_fn(): + with logcontext.PreserveLoggingContext(): + yield complete_lookup + defer.returnValue(1) + + return inner_fn() + + @defer.inlineCallbacks + def do_lookup(): + with logcontext.LoggingContext() as c1: + c1.name = "c1" + r = yield obj.fn(1) + self.assertEqual(logcontext.LoggingContext.current_context(), + c1) + defer.returnValue(r) + + def check_result(r): + self.assertEqual(r, 1) + + obj = Cls() + + # set off a deferred which will do a cache lookup + d1 = do_lookup() + self.assertEqual(logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel) + d1.addCallback(check_result) + + # and another + d2 = do_lookup() + self.assertEqual(logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel) + d2.addCallback(check_result) + + # let the lookup complete + complete_lookup.callback(None) + + return defer.gatherResults([d1, d2]) + + def test_cache_logcontexts_with_exception(self): + """Check that the cache sets and restores logcontexts correctly when + the lookup function throws an exception""" + + class Cls(object): + @descriptors.cached() + def fn(self, arg1): + @defer.inlineCallbacks + def inner_fn(): + yield async.run_on_reactor() + raise SynapseError(400, "blah") + + return inner_fn() + + @defer.inlineCallbacks + def do_lookup(): + with logcontext.LoggingContext() as c1: + c1.name = "c1" + try: + yield obj.fn(1) + self.fail("No exception thrown") + except SynapseError: + pass + + self.assertEqual(logcontext.LoggingContext.current_context(), + c1) + + obj = Cls() + + # set off a deferred which will do a cache lookup + d1 = do_lookup() + self.assertEqual(logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel) + + return d1