Clean up some LoggingContext stuff (#7120)

* Pull Sentinel out of LoggingContext

... and drop a few unnecessary references to it

* Factor out LoggingContext.current_context

move `current_context` and `set_context` out to top-level functions.

Mostly this means that I can more easily trace what's actually referring to
LoggingContext, but I think it's generally neater.

* move copy-to-parent into `stop`

this really just makes `start` and `stop` more symetric. It also means that it
behaves correctly if you manually `set_log_context` rather than using the
context manager.

* Replace `LoggingContext.alive` with `finished`

Turn `alive` into `finished` and make it a bit better defined.
This commit is contained in:
Richard van der Hoff 2020-03-24 14:45:33 +00:00 committed by GitHub
parent 1fcf9c6f95
commit 39230d2171
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 232 additions and 222 deletions

1
changelog.d/7120.misc Normal file
View file

@ -0,0 +1 @@
Clean up some LoggingContext code.

View file

@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets
def handle_request(request_id): def handle_request(request_id):
request_context = context.LoggingContext() request_context = context.LoggingContext()
calling_context = context.LoggingContext.current_context() calling_context = context.set_current_context(request_context)
context.LoggingContext.set_current_context(request_context)
try: try:
request_context.request = request_id request_context.request = request_id
do_request_handling() do_request_handling()
logger.debug("finished") logger.debug("finished")
finally: finally:
context.LoggingContext.set_current_context(calling_context) context.set_current_context(calling_context)
def do_request_handling(): def do_request_handling():
logger.debug("phew") # this will be logged against request_id logger.debug("phew") # this will be logged against request_id

View file

@ -43,8 +43,8 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.logging.context import ( from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
current_context,
make_deferred_yieldable, make_deferred_yieldable,
preserve_fn, preserve_fn,
run_in_background, run_in_background,
@ -236,7 +236,7 @@ class Keyring(object):
""" """
try: try:
ctx = LoggingContext.current_context() ctx = current_context()
# map from server name to a set of outstanding request ids # map from server name to a set of outstanding request ids
server_to_request_ids = {} server_to_request_ids = {}

View file

@ -32,8 +32,8 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import ( from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
current_context,
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
@ -78,7 +78,7 @@ class FederationBase(object):
""" """
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
ctx = LoggingContext.current_context() ctx = current_context()
def callback(_, pdu: EventBase): def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx): with PreserveLoggingContext(ctx):

View file

@ -26,7 +26,7 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import LoggingContext from synapse.logging.context import current_context
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -301,7 +301,7 @@ class SyncHandler(object):
else: else:
sync_type = "incremental_sync" sync_type = "incremental_sync"
context = LoggingContext.current_context() context = current_context()
if context: if context:
context.tag = sync_type context.tag = sync_type

View file

@ -19,7 +19,7 @@ import threading
from prometheus_client.core import Counter, Histogram from prometheus_client.core import Counter, Histogram
from synapse.logging.context import LoggingContext from synapse.logging.context import current_context
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -148,7 +148,7 @@ LaterGauge(
class RequestMetrics(object): class RequestMetrics(object):
def start(self, time_sec, name, method): def start(self, time_sec, name, method):
self.start = time_sec self.start = time_sec
self.start_context = LoggingContext.current_context() self.start_context = current_context()
self.name = name self.name = name
self.method = method self.method = method
@ -163,7 +163,7 @@ class RequestMetrics(object):
with _in_flight_requests_lock: with _in_flight_requests_lock:
_in_flight_requests.discard(self) _in_flight_requests.discard(self)
context = LoggingContext.current_context() context = current_context()
tag = "" tag = ""
if context: if context:

View file

@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver, TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver, TerseJSONToTCPLogObserver,
) )
from synapse.logging.context import LoggingContext from synapse.logging.context import current_context
def stdlib_log_level_to_twisted(level: str) -> LogLevel: def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@ -86,7 +86,7 @@ class LogContextObserver(object):
].startswith("Timing out client"): ].startswith("Timing out client"):
return return
context = LoggingContext.current_context() context = current_context()
# Copy the context information to the log event. # Copy the context information to the log event.
if context is not None: if context is not None:

View file

@ -175,7 +175,54 @@ class ContextResourceUsage(object):
return res return res
LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"] LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
class _Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "finished", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.finished = False
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
SENTINEL_CONTEXT = _Sentinel()
class LoggingContext(object): class LoggingContext(object):
@ -199,76 +246,33 @@ class LoggingContext(object):
"_resource_usage", "_resource_usage",
"usage_start", "usage_start",
"main_thread", "main_thread",
"alive", "finished",
"request", "request",
"tag", "tag",
"scope", "scope",
] ]
thread_local = threading.local()
class Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = ["previous_context", "alive", "request", "scope", "tag"]
def __init__(self) -> None:
# Minimal set for compatibility with LoggingContext
self.previous_context = None
self.alive = None
self.request = None
self.scope = None
self.tag = None
def __str__(self):
return "sentinel"
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_sec):
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
return False
__bool__ = __nonzero__ # python3
sentinel = Sentinel()
def __init__(self, name=None, parent_context=None, request=None) -> None: def __init__(self, name=None, parent_context=None, request=None) -> None:
self.previous_context = LoggingContext.current_context() self.previous_context = current_context()
self.name = name self.name = name
# track the resources used by this context so far # track the resources used by this context so far
self._resource_usage = ContextResourceUsage() self._resource_usage = ContextResourceUsage()
# If alive has the thread resource usage when the logcontext last # The thread resource usage when the logcontext became active. None
# became active. # if the context is not currently active.
self.usage_start = None self.usage_start = None
self.main_thread = get_thread_id() self.main_thread = get_thread_id()
self.request = None self.request = None
self.tag = "" self.tag = ""
self.alive = True
self.scope = None # type: Optional[_LogContextScope] self.scope = None # type: Optional[_LogContextScope]
# keep track of whether we have hit the __exit__ block for this context
# (suggesting that the the thing that created the context thinks it should
# be finished, and that re-activating it would suggest an error).
self.finished = False
self.parent_context = parent_context self.parent_context = parent_context
if self.parent_context is not None: if self.parent_context is not None:
@ -283,44 +287,15 @@ class LoggingContext(object):
return str(self.request) return str(self.request)
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
Returns:
LoggingContext: the current logging context
"""
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = cls.current_context()
if current is not context:
current.stop()
cls.thread_local.current_context = context
context.start()
return current
def __enter__(self) -> "LoggingContext": def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage""" """Enters this logging context into thread local storage"""
old_context = self.set_current_context(self) old_context = set_current_context(self)
if self.previous_context != old_context: if self.previous_context != old_context:
logger.warning( logger.warning(
"Expected previous context %r, found %r", "Expected previous context %r, found %r",
self.previous_context, self.previous_context,
old_context, old_context,
) )
self.alive = True
return self return self
def __exit__(self, type, value, traceback) -> None: def __exit__(self, type, value, traceback) -> None:
@ -329,24 +304,19 @@ class LoggingContext(object):
Returns: Returns:
None to avoid suppressing any exceptions that were thrown. None to avoid suppressing any exceptions that were thrown.
""" """
current = self.set_current_context(self.previous_context) current = set_current_context(self.previous_context)
if current is not self: if current is not self:
if current is self.sentinel: if current is SENTINEL_CONTEXT:
logger.warning("Expected logging context %s was lost", self) logger.warning("Expected logging context %s was lost", self)
else: else:
logger.warning( logger.warning(
"Expected logging context %s but found %s", self, current "Expected logging context %s but found %s", self, current
) )
self.alive = False
# if we have a parent, pass our CPU usage stats on # the fact that we are here suggests that the caller thinks that everything
if self.parent_context is not None and hasattr( # is done and dusted for this logcontext, and further activity will not get
self.parent_context, "_resource_usage" # recorded against the correct metrics.
): self.finished = True
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
def copy_to(self, record) -> None: def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or """Copy logging fields from this context to a log record or
@ -371,9 +341,14 @@ class LoggingContext(object):
logger.warning("Started logcontext %s on different thread", self) logger.warning("Started logcontext %s on different thread", self)
return return
if self.finished:
logger.warning("Re-starting finished log context %s", self)
# If we haven't already started record the thread resource usage so # If we haven't already started record the thread resource usage so
# far # far
if not self.usage_start: if self.usage_start:
logger.warning("Re-starting already-active log context %s", self)
else:
self.usage_start = get_thread_resource_usage() self.usage_start = get_thread_resource_usage()
def stop(self) -> None: def stop(self) -> None:
@ -396,6 +371,15 @@ class LoggingContext(object):
self.usage_start = None self.usage_start = None
# if we have a parent, pass our CPU usage stats on
if self.parent_context is not None and hasattr(
self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
# reset them in case we get entered again
self._resource_usage.reset()
def get_resource_usage(self) -> ContextResourceUsage: def get_resource_usage(self) -> ContextResourceUsage:
"""Get resources used by this logcontext so far. """Get resources used by this logcontext so far.
@ -409,7 +393,7 @@ class LoggingContext(object):
# If we are on the correct thread and we're currently running then we # If we are on the correct thread and we're currently running then we
# can include resource usage so far. # can include resource usage so far.
is_main_thread = get_thread_id() == self.main_thread is_main_thread = get_thread_id() == self.main_thread
if self.alive and self.usage_start and is_main_thread: if self.usage_start and is_main_thread:
utime_delta, stime_delta = self._get_cputime() utime_delta, stime_delta = self._get_cputime()
res.ru_utime += utime_delta res.ru_utime += utime_delta
res.ru_stime += stime_delta res.ru_stime += stime_delta
@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter):
Returns: Returns:
True to include the record in the log output. True to include the record in the log output.
""" """
context = LoggingContext.current_context() context = current_context()
for key, value in self.defaults.items(): for key, value in self.defaults.items():
setattr(record, key, value) setattr(record, key, value)
@ -512,27 +496,24 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"] __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None: def __init__(
if new_context is None: self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel ) -> None:
else: self.new_context = new_context
self.new_context = new_context
def __enter__(self) -> None: def __enter__(self) -> None:
"""Captures the current logging context""" """Captures the current logging context"""
self.current_context = LoggingContext.set_current_context(self.new_context) self.current_context = set_current_context(self.new_context)
if self.current_context: if self.current_context:
self.has_parent = self.current_context.previous_context is not None self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback) -> None: def __exit__(self, type, value, traceback) -> None:
"""Restores the current logging context""" """Restores the current logging context"""
context = LoggingContext.set_current_context(self.current_context) context = set_current_context(self.current_context)
if context != self.new_context: if context != self.new_context:
if context is LoggingContext.sentinel: if not context:
logger.warning("Expected logging context %s was lost", self.new_context) logger.warning("Expected logging context %s was lost", self.new_context)
else: else:
logger.warning( logger.warning(
@ -541,9 +522,30 @@ class PreserveLoggingContext(object):
context, context,
) )
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive: _thread_local = threading.local()
logger.debug("Restoring dead context: %s", self.current_context) _thread_local.current_context = SENTINEL_CONTEXT
def current_context() -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage"""
return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = current_context()
if current is not context:
current.stop()
_thread_local.current_context = context
context.start()
return current
def nested_logging_context( def nested_logging_context(
@ -572,7 +574,7 @@ def nested_logging_context(
if parent_context is not None: if parent_context is not None:
context = parent_context # type: LoggingContextOrSentinel context = parent_context # type: LoggingContextOrSentinel
else: else:
context = LoggingContext.current_context() context = current_context()
return LoggingContext( return LoggingContext(
parent_context=context, request=str(context.request) + "-" + suffix parent_context=context, request=str(context.request) + "-" + suffix
) )
@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs):
CRITICAL error about an unhandled error will be logged without much CRITICAL error about an unhandled error will be logged without much
indication about where it came from. indication about where it came from.
""" """
current = LoggingContext.current_context() current = current_context()
try: try:
res = f(*args, **kwargs) res = f(*args, **kwargs)
except: # noqa: E722 except: # noqa: E722
@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs):
# The function may have reset the context before returning, so # The function may have reset the context before returning, so
# we need to restore it now. # we need to restore it now.
ctx = LoggingContext.set_current_context(current) ctx = set_current_context(current)
# The original context will be restored when the deferred # The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will # completes, but there is nothing waiting for it, so it will
@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred):
# ok, we can't be sure that a yield won't block, so let's reset the # ok, we can't be sure that a yield won't block, so let's reset the
# logcontext, and add a callback to the deferred to restore it. # logcontext, and add a callback to the deferred to restore it.
prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) prev_context = set_current_context(SENTINEL_CONTEXT)
deferred.addBoth(_set_context_cb, prev_context) deferred.addBoth(_set_context_cb, prev_context)
return deferred return deferred
@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT")
def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
"""A callback function which just sets the logging context""" """A callback function which just sets the logging context"""
LoggingContext.set_current_context(context) set_current_context(context)
return result return result
@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
Deferred: A Deferred which fires a callback with the result of `f`, or an Deferred: A Deferred which fires a callback with the result of `f`, or an
errback if `f` throws an exception. errback if `f` throws an exception.
""" """
logcontext = LoggingContext.current_context() logcontext = current_context()
def g(): def g():
with LoggingContext(parent_context=logcontext): with LoggingContext(parent_context=logcontext):

View file

@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
import twisted import twisted
from synapse.logging.context import LoggingContext, nested_logging_context from synapse.logging.context import current_context, nested_logging_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
(Scope) : the Scope that is active, or None if not (Scope) : the Scope that is active, or None if not
available. available.
""" """
ctx = LoggingContext.current_context() ctx = current_context()
if ctx is LoggingContext.sentinel: return ctx.scope
return None
else:
return ctx.scope
def activate(self, span, finish_on_close): def activate(self, span, finish_on_close):
""" """
@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
""" """
enter_logcontext = False enter_logcontext = False
ctx = LoggingContext.current_context() ctx = current_context()
if ctx is LoggingContext.sentinel: if not ctx:
# We don't want this scope to affect. # We don't want this scope to affect.
logger.error("Tried to activate scope outside of loggingcontext") logger.error("Tried to activate scope outside of loggingcontext")
return Scope(None, span) return Scope(None, span)

View file

@ -35,7 +35,7 @@ from synapse.api.room_versions import (
) )
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database from synapse.storage.database import Database
@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map] missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids: if missing_events_ids:
log_ctx = LoggingContext.current_context() log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids)) log_ctx.record_event_fetch(len(missing_events_ids))
# Note that _get_events_from_db is also responsible for turning db rows # Note that _get_events_from_db is also responsible for turning db rows

View file

@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import ( from synapse.logging.context import (
LoggingContext, LoggingContext,
LoggingContextOrSentinel, LoggingContextOrSentinel,
current_context,
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -483,7 +484,7 @@ class Database(object):
end = monotonic_time() end = monotonic_time()
duration = end - start duration = end - start
LoggingContext.current_context().add_database_transaction(duration) current_context().add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
@ -510,7 +511,7 @@ class Database(object):
after_callbacks = [] # type: List[_CallbackListEntry] after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel: if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc) logger.warning("Starting db txn '%s' from sentinel context", desc)
try: try:
@ -547,10 +548,8 @@ class Database(object):
Returns: Returns:
Deferred: The result of func Deferred: The result of func
""" """
parent_context = ( parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
LoggingContext.current_context() if not parent_context:
) # type: Optional[LoggingContextOrSentinel]
if parent_context == LoggingContext.sentinel:
logger.warning( logger.warning(
"Starting db connection from sentinel context: metrics will be lost" "Starting db connection from sentinel context: metrics will be lost"
) )

View file

@ -21,7 +21,7 @@ from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge from synapse.metrics import InFlightGauge
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -106,7 +106,7 @@ class Measure(object):
raise RuntimeError("Measure() objects cannot be re-used") raise RuntimeError("Measure() objects cannot be re-used")
self.start = self.clock.time() self.start = self.clock.time()
parent_context = LoggingContext.current_context() parent_context = current_context()
self._logging_context = LoggingContext( self._logging_context = LoggingContext(
"Measure[%s]" % (self.name,), parent_context "Measure[%s]" % (self.name,), parent_context
) )

View file

@ -32,7 +32,7 @@ def do_patch():
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
""" """
from synapse.logging.context import LoggingContext from synapse.logging.context import current_context
global _already_patched global _already_patched
@ -43,35 +43,35 @@ def do_patch():
def new_inline_callbacks(f): def new_inline_callbacks(f):
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
start_context = LoggingContext.current_context() start_context = current_context()
changes = [] # type: List[str] changes = [] # type: List[str]
orig = orig_inline_callbacks(_check_yield_points(f, changes)) orig = orig_inline_callbacks(_check_yield_points(f, changes))
try: try:
res = orig(*args, **kwargs) res = orig(*args, **kwargs)
except Exception: except Exception:
if LoggingContext.current_context() != start_context: if current_context() != start_context:
for err in changes: for err in changes:
print(err, file=sys.stderr) print(err, file=sys.stderr)
err = "%s changed context from %s to %s on exception" % ( err = "%s changed context from %s to %s on exception" % (
f, f,
start_context, start_context,
LoggingContext.current_context(), current_context(),
) )
print(err, file=sys.stderr) print(err, file=sys.stderr)
raise Exception(err) raise Exception(err)
raise raise
if not isinstance(res, Deferred) or res.called: if not isinstance(res, Deferred) or res.called:
if LoggingContext.current_context() != start_context: if current_context() != start_context:
for err in changes: for err in changes:
print(err, file=sys.stderr) print(err, file=sys.stderr)
err = "Completed %s changed context from %s to %s" % ( err = "Completed %s changed context from %s to %s" % (
f, f,
start_context, start_context,
LoggingContext.current_context(), current_context(),
) )
# print the error to stderr because otherwise all we # print the error to stderr because otherwise all we
# see in travis-ci is the 500 error # see in travis-ci is the 500 error
@ -79,23 +79,23 @@ def do_patch():
raise Exception(err) raise Exception(err)
return res return res
if LoggingContext.current_context() != LoggingContext.sentinel: if current_context():
err = ( err = (
"%s returned incomplete deferred in non-sentinel context " "%s returned incomplete deferred in non-sentinel context "
"%s (start was %s)" "%s (start was %s)"
) % (f, LoggingContext.current_context(), start_context) ) % (f, current_context(), start_context)
print(err, file=sys.stderr) print(err, file=sys.stderr)
raise Exception(err) raise Exception(err)
def check_ctx(r): def check_ctx(r):
if LoggingContext.current_context() != start_context: if current_context() != start_context:
for err in changes: for err in changes:
print(err, file=sys.stderr) print(err, file=sys.stderr)
err = "%s completion of %s changed context from %s to %s" % ( err = "%s completion of %s changed context from %s to %s" % (
"Failure" if isinstance(r, Failure) else "Success", "Failure" if isinstance(r, Failure) else "Success",
f, f,
start_context, start_context,
LoggingContext.current_context(), current_context(),
) )
print(err, file=sys.stderr) print(err, file=sys.stderr)
raise Exception(err) raise Exception(err)
@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
function function
""" """
from synapse.logging.context import LoggingContext from synapse.logging.context import current_context
@functools.wraps(f) @functools.wraps(f)
def check_yield_points_inner(*args, **kwargs): def check_yield_points_inner(*args, **kwargs):
@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
last_yield_line_no = gen.gi_frame.f_lineno last_yield_line_no = gen.gi_frame.f_lineno
result = None # type: Any result = None # type: Any
while True: while True:
expected_context = LoggingContext.current_context() expected_context = current_context()
try: try:
isFailure = isinstance(result, Failure) isFailure = isinstance(result, Failure)
@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
else: else:
d = gen.send(result) d = gen.send(result)
except (StopIteration, defer._DefGen_Return) as e: except (StopIteration, defer._DefGen_Return) as e:
if LoggingContext.current_context() != expected_context: if current_context() != expected_context:
# This happens when the context is lost sometime *after* the # This happens when the context is lost sometime *after* the
# final yield and returning. E.g. we forgot to yield on a # final yield and returning. E.g. we forgot to yield on a
# function that returns a deferred. # function that returns a deferred.
@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
% ( % (
f.__qualname__, f.__qualname__,
expected_context, expected_context,
LoggingContext.current_context(), current_context(),
f.__code__.co_filename, f.__code__.co_filename,
last_yield_line_no, last_yield_line_no,
) )
@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
# This happens if we yield on a deferred that doesn't follow # This happens if we yield on a deferred that doesn't follow
# the log context rules without wrapping in a `make_deferred_yieldable`. # the log context rules without wrapping in a `make_deferred_yieldable`.
# We raise here as this should never happen. # We raise here as this should never happen.
if LoggingContext.current_context() is not LoggingContext.sentinel: if current_context():
err = ( err = (
"%s yielded with context %s rather than sentinel," "%s yielded with context %s rather than sentinel,"
" yielded on line %d in %s" " yielded on line %d in %s"
% ( % (
frame.f_code.co_name, frame.f_code.co_name,
LoggingContext.current_context(), current_context(),
frame.f_lineno, frame.f_lineno,
frame.f_code.co_filename, frame.f_code.co_filename,
) )
@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
except Exception as e: except Exception as e:
result = Failure(e) result = Failure(e)
if LoggingContext.current_context() != expected_context: if current_context() != expected_context:
# This happens because the context is lost sometime *after* the # This happens because the context is lost sometime *after* the
# previous yield and *after* the current yield. E.g. the # previous yield and *after* the current yield. E.g. the
@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
% ( % (
frame.f_code.co_name, frame.f_code.co_name,
expected_context, expected_context,
LoggingContext.current_context(), current_context(),
last_yield_line_no, last_yield_line_no,
frame.f_lineno, frame.f_lineno,
frame.f_code.co_filename, frame.f_code.co_filename,

View file

@ -34,6 +34,7 @@ from synapse.crypto.keyring import (
from synapse.logging.context import ( from synapse.logging.context import (
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
current_context,
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
) )
def check_context(self, _, expected): def check_context(self, _, expected):
self.assertEquals( self.assertEquals(getattr(current_context(), "request", None), expected)
getattr(LoggingContext.current_context(), "request", None), expected
)
def test_verify_json_objects_for_server_awaits_previous_requests(self): def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_perspectives(**kwargs): def get_perspectives(**kwargs):
self.assertEquals(LoggingContext.current_context().request, "11") self.assertEquals(current_context().request, "11")
with PreserveLoggingContext(): with PreserveLoggingContext():
yield persp_deferred yield persp_deferred
return persp_resp return persp_resp

View file

@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import (
WellKnownResolver, WellKnownResolver,
_cache_period_from_headers, _cache_period_from_headers,
) )
from synapse.logging.context import LoggingContext from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.util.caches.ttlcache import TTLCache from synapse.util.caches.ttlcache import TTLCache
from tests import unittest from tests import unittest
@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(fetch_d) self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel # should have reset logcontext to the sentinel
_check_logcontext(LoggingContext.sentinel) _check_logcontext(SENTINEL_CONTEXT)
try: try:
fetch_res = yield fetch_d fetch_res = yield fetch_d
@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
def _check_logcontext(context): def _check_logcontext(context):
current = LoggingContext.current_context() current = current_context()
if current is not context: if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current)) raise AssertionError("Expected logcontext %s but was %s" % (context, current))

View file

@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver from synapse.http.federation.srv_resolver import SrvResolver
from synapse.logging.context import LoggingContext from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests import unittest from tests import unittest
from tests.utils import MockClock from tests.utils import MockClock
@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertNoResult(resolve_d) self.assertNoResult(resolve_d)
# should have reset to the sentinel context # should have reset to the sentinel context
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) self.assertIs(current_context(), SENTINEL_CONTEXT)
result = yield resolve_d result = yield resolve_d
# should have restored our context # should have restored our context
self.assertIs(LoggingContext.current_context(), ctx) self.assertIs(current_context(), ctx)
return result return result

View file

@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient, MatrixFederationHttpClient,
MatrixFederationRequest, MatrixFederationRequest,
) )
from synapse.logging.context import LoggingContext from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests.server import FakeTransport from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
def check_logcontext(context): def check_logcontext(context):
current = LoggingContext.current_context() current = current_context()
if current is not context: if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current)) raise AssertionError("Expected logcontext %s but was %s" % (context, current))
@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertNoResult(fetch_d) self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel # should have reset logcontext to the sentinel
check_logcontext(LoggingContext.sentinel) check_logcontext(SENTINEL_CONTEXT)
try: try:
fetch_res = yield fetch_d fetch_res = yield fetch_d

View file

@ -2,7 +2,7 @@ from mock import Mock, call
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.logging.context import LoggingContext from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.util import Clock from synapse.util import Clock
@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test(): def test():
with LoggingContext("c") as c1: with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertIs(LoggingContext.current_context(), c1) self.assertIs(current_context(), c1)
self.assertEqual(res, "yay") self.assertEqual(res, "yay")
# run the test twice in parallel # run the test twice in parallel
d = defer.gatherResults([test(), test()]) d = defer.gatherResults([test(), test()])
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) self.assertIs(current_context(), SENTINEL_CONTEXT)
yield d yield d
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) self.assertIs(current_context(), SENTINEL_CONTEXT)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_does_not_cache_exceptions(self): def test_does_not_cache_exceptions(self):
@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb) yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e: except Exception as e:
self.assertEqual(e.args[0], "boo") self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context) self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response) self.assertEqual(res, self.mock_http_response)
self.assertIs(LoggingContext.current_context(), test_context) self.assertIs(current_context(), test_context)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_does_not_cache_failures(self): def test_does_not_cache_failures(self):
@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb) yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e: except Exception as e:
self.assertEqual(e.args[0], "boo") self.assertEqual(e.args[0], "boo")
self.assertIs(LoggingContext.current_context(), test_context) self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb) res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response) self.assertEqual(res, self.mock_http_response)
self.assertIs(LoggingContext.current_context(), test_context) self.assertIs(current_context(), test_context)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cleans_up(self): def test_cleans_up(self):

View file

@ -38,7 +38,11 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import LoggingContext from synapse.logging.context import (
SENTINEL_CONTEXT,
current_context,
set_current_context,
)
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester from synapse.types import Requester, UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
@ -97,10 +101,10 @@ class TestCase(unittest.TestCase):
def setUp(orig): def setUp(orig):
# if we're not starting in the sentinel logcontext, then to be honest # if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off. # all future bets are off.
if LoggingContext.current_context() is not LoggingContext.sentinel: if current_context():
self.fail( self.fail(
"Test starting with non-sentinel logging context %s" "Test starting with non-sentinel logging context %s"
% (LoggingContext.current_context(),) % (current_context(),)
) )
old_level = logging.getLogger().level old_level = logging.getLogger().level
@ -122,7 +126,7 @@ class TestCase(unittest.TestCase):
# force a GC to workaround problems with deferreds leaking logcontexts when # force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs) # they are GCed (see the logcontext docs)
gc.collect() gc.collect()
LoggingContext.set_current_context(LoggingContext.sentinel) set_current_context(SENTINEL_CONTEXT)
return ret return ret

View file

@ -22,8 +22,10 @@ from twisted.internet import defer, reactor
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.logging.context import ( from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
current_context,
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase):
with LoggingContext() as c1: with LoggingContext() as c1:
c1.name = "c1" c1.name = "c1"
r = yield obj.fn(1) r = yield obj.fn(1)
self.assertEqual(LoggingContext.current_context(), c1) self.assertEqual(current_context(), c1)
return r return r
def check_result(r): def check_result(r):
@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup # set off a deferred which will do a cache lookup
d1 = do_lookup() d1 = do_lookup()
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) self.assertEqual(current_context(), SENTINEL_CONTEXT)
d1.addCallback(check_result) d1.addCallback(check_result)
# and another # and another
d2 = do_lookup() d2 = do_lookup()
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) self.assertEqual(current_context(), SENTINEL_CONTEXT)
d2.addCallback(check_result) d2.addCallback(check_result)
# let the lookup complete # let the lookup complete
@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase):
try: try:
d = obj.fn(1) d = obj.fn(1)
self.assertEqual( self.assertEqual(
LoggingContext.current_context(), LoggingContext.sentinel current_context(), SENTINEL_CONTEXT,
) )
yield d yield d
self.fail("No exception thrown") self.fail("No exception thrown")
except SynapseError: except SynapseError:
pass pass
self.assertEqual(LoggingContext.current_context(), c1) self.assertEqual(current_context(), c1)
# the cache should now be empty # the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0) self.assertEqual(len(obj.fn.cache.cache), 0)
@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup # set off a deferred which will do a cache lookup
d1 = do_lookup() d1 = do_lookup()
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) self.assertEqual(current_context(), SENTINEL_CONTEXT)
return d1 return d1
@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True) @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2): def list_fn(self, args1, arg2):
assert LoggingContext.current_context().request == "c1" assert current_context().request == "c1"
# we want this to behave like an asynchronous function # we want this to behave like an asynchronous function
yield run_on_reactor() yield run_on_reactor()
assert LoggingContext.current_context().request == "c1" assert current_context().request == "c1"
return self.mock(args1, arg2) return self.mock(args1, arg2)
with LoggingContext() as c1: with LoggingContext() as c1:
@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj = Cls() obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"} obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2) d1 = obj.list_fn([10, 20], 2)
self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) self.assertEqual(current_context(), SENTINEL_CONTEXT)
r = yield d1 r = yield d1
self.assertEqual(LoggingContext.current_context(), c1) self.assertEqual(current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2) obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: "fish", 20: "chips"}) self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock() obj.mock.reset_mock()

View file

@ -16,7 +16,12 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock from twisted.internet.task import Clock
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
current_context,
)
from synapse.util.async_helpers import timeout_deferred from synapse.util.async_helpers import timeout_deferred
from tests.unittest import TestCase from tests.unittest import TestCase
@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase):
# the errbacks should be run in the test logcontext # the errbacks should be run in the test logcontext
def errback(res, deferred_name): def errback(res, deferred_name):
self.assertIs( self.assertIs(
LoggingContext.current_context(), current_context(),
context_one, context_one,
"errback %s run in unexpected logcontext %s" "errback %s run in unexpected logcontext %s"
% (deferred_name, LoggingContext.current_context()), % (deferred_name, current_context()),
) )
return res return res
@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase):
original_deferred.addErrback(errback, "orig") original_deferred.addErrback(errback, "orig")
timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
self.assertNoResult(timing_out_d) self.assertNoResult(timing_out_d)
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) self.assertIs(current_context(), SENTINEL_CONTEXT)
timing_out_d.addErrback(errback, "timingout") timing_out_d.addErrback(errback, "timingout")
self.clock.pump((1.0,)) self.clock.pump((1.0,))
@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase):
blocking_was_cancelled[0], "non-completing deferred was not cancelled" blocking_was_cancelled[0], "non-completing deferred was not cancelled"
) )
self.failureResultOf(timing_out_d, defer.TimeoutError) self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(LoggingContext.current_context(), context_one) self.assertIs(current_context(), context_one)

View file

@ -19,7 +19,7 @@ from six.moves import range
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext, current_context
from synapse.util import Clock from synapse.util import Clock
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase):
def func(i, sleep=False): def func(i, sleep=False):
with LoggingContext("func(%s)" % i) as lc: with LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")): with (yield linearizer.queue("")):
self.assertEqual(LoggingContext.current_context(), lc) self.assertEqual(current_context(), lc)
if sleep: if sleep:
yield Clock(reactor).sleep(0) yield Clock(reactor).sleep(0)
self.assertEqual(LoggingContext.current_context(), lc) self.assertEqual(current_context(), lc)
func(0, sleep=True) func(0, sleep=True)
for i in range(1, 100): for i in range(1, 100):

View file

@ -2,8 +2,10 @@ import twisted.python.failure
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.logging.context import ( from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
current_context,
make_deferred_yieldable, make_deferred_yieldable,
nested_logging_context, nested_logging_context,
run_in_background, run_in_background,
@ -15,7 +17,7 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase): class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value): def _check_test_key(self, value):
self.assertEquals(LoggingContext.current_context().request, value) self.assertEquals(current_context().request, value)
def test_with_context(self): def test_with_context(self):
with LoggingContext() as context_one: with LoggingContext() as context_one:
@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one") self._check_test_key("one")
def _test_run_in_background(self, function): def _test_run_in_background(self, function):
sentinel_context = LoggingContext.current_context() sentinel_context = current_context()
callback_completed = [False] callback_completed = [False]
@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase):
# make sure that the context was reset before it got thrown back # make sure that the context was reset before it got thrown back
# into the reactor # into the reactor
try: try:
self.assertIs(LoggingContext.current_context(), sentinel_context) self.assertIs(current_context(), sentinel_context)
d2.callback(None) d2.callback(None)
except BaseException: except BaseException:
d2.errback(twisted.python.failure.Failure()) d2.errback(twisted.python.failure.Failure())
@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase):
async def testfunc(): async def testfunc():
self._check_test_key("one") self._check_test_key("one")
d = Clock(reactor).sleep(0) d = Clock(reactor).sleep(0)
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) self.assertIs(current_context(), SENTINEL_CONTEXT)
await d await d
self._check_test_key("one") self._check_test_key("one")
@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None) reactor.callLater(0, d.callback, None)
return d return d
sentinel_context = LoggingContext.current_context() sentinel_context = current_context()
with LoggingContext() as context_one: with LoggingContext() as context_one:
context_one.request = "one" context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function()) d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable # make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context) self.assertIs(current_context(), sentinel_context)
yield d1 yield d1
@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self): def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = LoggingContext.current_context() sentinel_context = current_context()
with LoggingContext() as context_one: with LoggingContext() as context_one:
context_one.request = "one" context_one.request = "one"
d1 = make_deferred_yieldable(_chained_deferred_function()) d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable # make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context) self.assertIs(current_context(), sentinel_context)
yield d1 yield d1
@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None) reactor.callLater(0, d.callback, None)
await d await d
sentinel_context = LoggingContext.current_context() sentinel_context = current_context()
with LoggingContext() as context_one: with LoggingContext() as context_one:
context_one.request = "one" context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function()) d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable # make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context) self.assertIs(current_context(), sentinel_context)
yield d1 yield d1

View file

@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.logging.context import LoggingContext from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.engines import PostgresEngine, create_engine
@ -493,10 +493,10 @@ class MockClock(object):
return self.time() * 1000 return self.time() * 1000
def call_later(self, delay, callback, *args, **kwargs): def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context() ctx = current_context()
def wrapped_callback(): def wrapped_callback():
LoggingContext.thread_local.current_context = current_context set_current_context(ctx)
callback(*args, **kwargs) callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False] t = [self.now + delay, wrapped_callback, False]