
Record more information into structured logs. (#9654)

Records additional request information into the structured logs,
e.g. the requester, IP address, etc.
Patrick Cloke 2021-04-08 08:01:14 -04:00 committed by GitHub
parent 0d87c6bd12
commit 48d44ab142
10 changed files with 255 additions and 88 deletions
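In practice, a JSON log line emitted while a request is being handled now carries the request metadata alongside the message. The field names below are the ones asserted by the new `test_with_request_context` test further down; the values are illustrative only, not taken from the commit:

# Illustrative shape of a structured log line after this change (made-up values).
example_log_line = {
    "log": "Handled request",
    "level": "INFO",
    "namespace": "synapse.access.http.8008",
    "request": "GET-42",
    "ip_address": "10.0.0.1",
    "site_tag": "client",
    "requester": "@alice:example.com",
    "authenticated_entity": "@alice:example.com",
    "method": "GET",
    "url": "/_matrix/client/versions",
    "protocol": "1.1",
    "user_agent": "curl/7.68.0",
}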

changelog.d/9654.feature (new file)

@@ -0,0 +1 @@
+Include request information in structured logging output.

synapse/http/site.py

@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union

 import attr
 from zope.interface import implementer
@@ -26,7 +26,11 @@ from twisted.web.server import Request, Site
 from synapse.config.server import ListenerConfig
 from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+    ContextRequest,
+    LoggingContext,
+    PreserveLoggingContext,
+)
 from synapse.types import Requester

 logger = logging.getLogger(__name__)
@@ -63,7 +67,7 @@ class SynapseRequest(Request):
         # The requester, if authenticated. For federation requests this is the
         # server name, for client requests this is the Requester object.
-        self.requester = None  # type: Optional[Union[Requester, str]]
+        self._requester = None  # type: Optional[Union[Requester, str]]

         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext = None  # type: Optional[LoggingContext]
@@ -93,6 +97,31 @@
             self.site.site_tag,
         )

+    @property
+    def requester(self) -> Optional[Union[Requester, str]]:
+        return self._requester
+
+    @requester.setter
+    def requester(self, value: Union[Requester, str]) -> None:
+        # Store the requester, and update some properties based on it.
+
+        # This should only be called once.
+        assert self._requester is None
+
+        self._requester = value
+
+        # A logging context should exist by now (and have a ContextRequest).
+        assert self.logcontext is not None
+        assert self.logcontext.request is not None
+
+        (
+            requester,
+            authenticated_entity,
+        ) = self.get_authenticated_entity()
+        self.logcontext.request.requester = requester
+        # If there's no authenticated entity, it was the requester.
+        self.logcontext.request.authenticated_entity = authenticated_entity or requester
+
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)
@@ -126,13 +155,60 @@
             return self.method.decode("ascii")
         return method

+    def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Get the "authenticated" entity of the request, which might be the user
+        performing the action, or a user being puppeted by a server admin.
+
+        Returns:
+            A tuple:
+                The first item is a string representing the user making the request.
+
+                The second item is a string or None representing the user who
+                authenticated when making this request. See
+                Requester.authenticated_entity.
+        """
+        # Convert the requester into a string that we can log
+        if isinstance(self._requester, str):
+            return self._requester, None
+        elif isinstance(self._requester, Requester):
+            requester = self._requester.user.to_string()
+            authenticated_entity = self._requester.authenticated_entity
+
+            # If this is a request where the target user doesn't match the user who
+            # authenticated (e.g. and admin is puppetting a user) then we return both.
+            if self._requester.user.to_string() != authenticated_entity:
+                return requester, authenticated_entity
+
+            return requester, None
+        elif self._requester is not None:
+            # This shouldn't happen, but we log it so we don't lose information
+            # and can see that we're doing something wrong.
+            return repr(self._requester), None  # type: ignore[unreachable]
+
+        return None, None
+
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.

         # create a LogContext for this request
         request_id = self.get_request_id()
-        self.logcontext = LoggingContext(request_id, request=request_id)
+        self.logcontext = LoggingContext(
+            request_id,
+            request=ContextRequest(
+                request_id=request_id,
+                ip_address=self.getClientIP(),
+                site_tag=self.site.site_tag,
+                # The requester is going to be unknown at this point.
+                requester=None,
+                authenticated_entity=None,
+                method=self.get_method(),
+                url=self.get_redacted_uri(),
+                protocol=self.clientproto.decode("ascii", errors="replace"),
+                user_agent=get_request_user_agent(self),
+            ),
+        )

         # override the Server header which is set by twisted
         self.setHeader("Server", self.site.server_version_string)
@@ -277,25 +353,6 @@
         # to the client (nb may be negative)
         response_send_time = self.finish_time - self._processing_finished_time

-        # Convert the requester into a string that we can log
-        authenticated_entity = None
-        if isinstance(self.requester, str):
-            authenticated_entity = self.requester
-        elif isinstance(self.requester, Requester):
-            authenticated_entity = self.requester.authenticated_entity
-
-            # If this is a request where the target user doesn't match the user who
-            # authenticated (e.g. and admin is puppetting a user) then we log both.
-            if self.requester.user.to_string() != authenticated_entity:
-                authenticated_entity = "{},{}".format(
-                    authenticated_entity,
-                    self.requester.user.to_string(),
-                )
-        elif self.requester is not None:
-            # This shouldn't happen, but we log it so we don't lose information
-            # and can see that we're doing something wrong.
-            authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
-
         user_agent = get_request_user_agent(self, "-")

         code = str(self.code)
@@ -305,6 +362,13 @@
             code += "!"

         log_level = logging.INFO if self._should_log_request() else logging.DEBUG
+
+        # If this is a request where the target user doesn't match the user who
+        # authenticated (e.g. and admin is puppetting a user) then we log both.
+        requester, authenticated_entity = self.get_authenticated_entity()
+        if authenticated_entity:
+            requester = "{}.{}".format(authenticated_entity, requester)
+
         self.site.access_logger.log(
             log_level,
             "%s - %s - {%s}"
@@ -312,7 +376,7 @@
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
             self.getClientIP(),
             self.site.site_tag,
-            authenticated_entity,
+            requester,
             processing_time,
             response_send_time,
             usage.ru_utime,
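The access-log line itself changes shape slightly: instead of the old in-place string building when the request finishes, the requester column is now derived from get_authenticated_entity(), and when a server admin is puppeting another user both identities are logged. A rough sketch of that formatting (the helper name is hypothetical, for illustration only):

def format_access_log_requester(requester, authenticated_entity):
    # Mirrors the logic above: authenticated_entity is only non-None when the
    # authenticated user differs from the target user (i.e. puppeting).
    if authenticated_entity:
        return "{}.{}".format(authenticated_entity, requester)
    return requester

# e.g. format_access_log_requester("@alice:example.com", "@admin:example.com")
# returns "@admin:example.com.@alice:example.com"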

synapse/logging/context.py

@@ -22,7 +22,6 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
-import inspect
 import logging
 import threading
@@ -30,6 +29,7 @@ import types
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union

+import attr
 from typing_extensions import Literal

 from twisted.internet import defer, threads
@@ -181,6 +181,29 @@ class ContextResourceUsage:
         return res


+@attr.s(slots=True)
+class ContextRequest:
+    """
+    A bundle of attributes from the SynapseRequest object.
+
+    This exists to:
+
+    * Avoid a cycle between LoggingContext and SynapseRequest.
+    * Be a single variable that can be passed from parent LoggingContexts to
+      their children.
+    """
+
+    request_id = attr.ib(type=str)
+    ip_address = attr.ib(type=str)
+    site_tag = attr.ib(type=str)
+    requester = attr.ib(type=Optional[str])
+    authenticated_entity = attr.ib(type=Optional[str])
+    method = attr.ib(type=str)
+    url = attr.ib(type=str)
+    protocol = attr.ib(type=str)
+    user_agent = attr.ib(type=str)
+
+
 LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
@@ -256,7 +279,7 @@ class LoggingContext:
         self,
         name: Optional[str] = None,
         parent_context: "Optional[LoggingContext]" = None,
-        request: Optional[str] = None,
+        request: Optional[ContextRequest] = None,
     ) -> None:
         self.previous_context = current_context()
         self.name = name
@@ -281,7 +304,11 @@ class LoggingContext:
         self.parent_context = parent_context

         if self.parent_context is not None:
-            self.parent_context.copy_to(self)
+            # we track the current request_id
+            self.request = self.parent_context.request
+
+            # we also track the current scope:
+            self.scope = self.parent_context.scope

         if request is not None:
             # the request param overrides the request from the parent context
@@ -289,7 +316,7 @@ class LoggingContext:

     def __str__(self) -> str:
         if self.request:
-            return str(self.request)
+            return self.request.request_id
         return "%s@%x" % (self.name, id(self))

     @classmethod
@@ -556,8 +583,23 @@ class LoggingContextFilter(logging.Filter):
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
-            # Logging is interested in the request.
-            record.request = context.request  # type: ignore
+            # Logging is interested in the request ID. Note that for backwards
+            # compatibility this is stored as the "request" on the record.
+            record.request = str(context)  # type: ignore
+
+            # Add some data from the HTTP request.
+            request = context.request
+            if request is None:
+                return True
+
+            record.ip_address = request.ip_address  # type: ignore
+            record.site_tag = request.site_tag  # type: ignore
+            record.requester = request.requester  # type: ignore
+            record.authenticated_entity = request.authenticated_entity  # type: ignore
+            record.method = request.method  # type: ignore
+            record.url = request.url  # type: ignore
+            record.protocol = request.protocol  # type: ignore
+            record.user_agent = request.user_agent  # type: ignore

         return True
@@ -630,8 +672,8 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
 def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.

-    The nested logging context will have a 'request' made up of the parent context's
-    request, plus the given suffix.
+    The nested logging context will have a 'name' made up of the parent context's
+    name, plus the given suffix.

     CPU/db usage stats will be added to the parent context's on exit.
@@ -641,7 +683,7 @@ def nested_logging_context(suffix: str) -> LoggingContext:
             # ... do stuff

     Args:
-        suffix: suffix to add to the parent context's 'request'.
+        suffix: suffix to add to the parent context's 'name'.

     Returns:
         LoggingContext: new logging context.
@@ -653,11 +695,17 @@ def nested_logging_context(suffix: str) -> LoggingContext:
         )
         parent_context = None
         prefix = ""
+        request = None
     else:
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
-        prefix = str(parent_context.request)
-    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
+        prefix = str(parent_context.name)
+        request = parent_context.request
+    return LoggingContext(
+        prefix + "-" + suffix,
+        parent_context=parent_context,
+        request=request,
+    )


 def preserve_fn(f):
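Downstream of these changes, any handler configured with LoggingContextFilter plus one of the JSON formatters picks the new fields up automatically for anything logged inside a context that carries a ContextRequest. A minimal sketch of that wiring, assuming a Synapse checkout that includes these changes (names and values are illustrative, not taken from the commit):

import logging
import sys

from synapse.logging._terse_json import JsonFormatter
from synapse.logging.context import ContextRequest, LoggingContext, LoggingContextFilter

# Attach the filter (which copies ContextRequest attributes onto each record)
# and a JSON formatter to an ordinary stdlib handler.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(JsonFormatter())
handler.addFilter(LoggingContextFilter())

logger = logging.getLogger("example")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

context_request = ContextRequest(
    request_id="GET-1",
    ip_address="127.0.0.1",
    site_tag="client",
    requester="@alice:example.com",
    authenticated_entity="@alice:example.com",
    method="GET",
    url="/_matrix/client/versions",
    protocol="1.1",
    user_agent="curl/7.68.0",
)

with LoggingContext("GET-1", request=context_request):
    # Emitted as a JSON object containing the message plus the request fields above.
    logger.info("handling request")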

synapse/metrics/background_process_metrics.py

@@ -16,7 +16,7 @@
 import logging
 import threading
 from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set
+from typing import TYPE_CHECKING, Dict, Optional, Set, Union

 from prometheus_client.core import REGISTRY, Counter, Gauge
@@ -199,11 +199,11 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
         _background_process_start_count.labels(desc).inc()
         _background_process_in_flight_count.labels(desc).inc()

-        with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
+        with BackgroundProcessLoggingContext(desc, count) as context:
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
-                    ctx = start_active_span(desc, tags={"request_id": context.request})
+                    ctx = start_active_span(desc, tags={"request_id": str(context)})
                 with ctx:
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
@@ -242,13 +242,19 @@ class BackgroundProcessLoggingContext(LoggingContext):
     processes.
     """

-    __slots__ = ["_proc"]
+    __slots__ = ["_id", "_proc"]

-    def __init__(self, name: str, request: Optional[str] = None):
-        super().__init__(name, request=request)
+    def __init__(self, name: str, id: Optional[Union[int, str]] = None):
+        super().__init__(name)
+        self._id = id

         self._proc = _BackgroundProcess(name, self)

+    def __str__(self) -> str:
+        if self._id is not None:
+            return "%s-%s" % (self.name, self._id)
+        return "%s@%x" % (self.name, id(self))
+
     def start(self, rusage: "Optional[resource._RUsage]"):
         """Log context has started running (again)."""

synapse/replication/tcp/protocol.py

@@ -184,8 +184,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
+        self._logging_context = BackgroundProcessLoggingContext(
+            "replication-conn", self.conn_id
+        )

     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())

tests/crypto/test_keyring.py

@@ -16,6 +16,7 @@ import time
 from mock import Mock

+import attr
 import canonicaljson
 import signedjson.key
 import signedjson.sign
@@ -68,6 +69,11 @@ class MockPerspectiveServer:
         signedjson.sign.sign_json(res, self.server_name, self.key)


+@attr.s(slots=True)
+class FakeRequest:
+    id = attr.ib()
+
+
 @logcontext_clean
 class KeyringTestCase(unittest.HomeserverTestCase):
     def check_context(self, val, expected):
@@ -89,7 +95,7 @@
         first_lookup_deferred = Deferred()

         async def first_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request, "context_11")
+            self.assertEquals(current_context().request.id, "context_11")
             self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})

             await make_deferred_yieldable(first_lookup_deferred)
@@ -102,9 +108,7 @@
         mock_fetcher.get_keys.side_effect = first_lookup_fetch

         async def first_lookup():
-            with LoggingContext("context_11") as context_11:
-                context_11.request = "context_11"
-
+            with LoggingContext("context_11", request=FakeRequest("context_11")):
                 res_deferreds = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
                 )
@@ -130,7 +134,7 @@
         # should block rather than start a second call

         async def second_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request, "context_12")
+            self.assertEquals(current_context().request.id, "context_12")
             return {
                 "server10": {
                     get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
@@ -142,9 +146,7 @@
         second_lookup_state = [0]

         async def second_lookup():
-            with LoggingContext("context_12") as context_12:
-                context_12.request = "context_12"
-
+            with LoggingContext("context_12", request=FakeRequest("context_12")):
                 res_deferreds_2 = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test")]
                 )
@@ -589,10 +591,7 @@ def get_key_id(key):

 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx") as ctx:
-        # we set the "request" prop to make it easier to follow what's going on in the
-        # logs.
-        ctx.request = "testctx"
+    with LoggingContext("testctx"):
         rv = yield f(*args, **kwargs)

     return rv

tests/logging/test_terse_json.py

@@ -12,15 +12,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
 import logging
-from io import StringIO
+from io import BytesIO, StringIO

+from mock import Mock, patch
+
+from twisted.web.server import Request
+
+from synapse.http.site import SynapseRequest
 from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
 from synapse.logging.context import LoggingContext, LoggingContextFilter

 from tests.logging import LoggerCleanupMixin
+from tests.server import FakeChannel
 from tests.unittest import TestCase
@@ -120,7 +125,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         handler.addFilter(LoggingContextFilter())
         logger = self.get_logger(handler)

-        with LoggingContext(request="test"):
+        with LoggingContext("name"):
             logger.info("Hello there, %s!", "wally")

         log = self.get_log_line()
@@ -134,4 +139,61 @@
         ]
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
-        self.assertEqual(log["request"], "test")
+        self.assertTrue(log["request"].startswith("name@"))
+
+    def test_with_request_context(self):
+        """
+        Information from the logging context request should be added to the JSON response.
+        """
+        handler = logging.StreamHandler(self.output)
+        handler.setFormatter(JsonFormatter())
+        handler.addFilter(LoggingContextFilter())
+        logger = self.get_logger(handler)
+
+        # A full request isn't needed here.
+        site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
+        site.site_tag = "test-site"
+        site.server_version_string = "Server v1"
+        request = SynapseRequest(FakeChannel(site, None))
+        # Call requestReceived to finish instantiating the object.
+        request.content = BytesIO()
+        # Partially skip some of the internal processing of SynapseRequest.
+        request._started_processing = Mock()
+        request.request_metrics = Mock(spec=["name"])
+        with patch.object(Request, "render"):
+            request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
+
+        # Also set the requester to ensure the processing works.
+        request.requester = "@foo:test"
+
+        with LoggingContext(parent_context=request.logcontext):
+            logger.info("Hello there, %s!", "wally")
+
+        log = self.get_log_line()
+
+        # The terse logger includes additional request information, if possible.
+        expected_log_keys = [
+            "log",
+            "level",
+            "namespace",
+            "request",
+            "ip_address",
+            "site_tag",
+            "requester",
+            "authenticated_entity",
+            "method",
+            "url",
+            "protocol",
+            "user_agent",
+        ]
+        self.assertCountEqual(log.keys(), expected_log_keys)
+        self.assertEqual(log["log"], "Hello there, wally!")
+        self.assertTrue(log["request"].startswith("POST-"))
+        self.assertEqual(log["ip_address"], "127.0.0.1")
+        self.assertEqual(log["site_tag"], "test-site")
+        self.assertEqual(log["requester"], "@foo:test")
+        self.assertEqual(log["authenticated_entity"], "@foo:test")
+        self.assertEqual(log["method"], "POST")
+        self.assertEqual(log["url"], "/_matrix/client/versions")
+        self.assertEqual(log["protocol"], "1.1")
+        self.assertEqual(log["user_agent"], "")

tests/unittest.py

@@ -471,7 +471,7 @@ class HomeserverTestCase(TestCase):
             kwargs["config"] = config_obj

             async def run_bg_updates():
-                with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+                with LoggingContext("run_bg_updates"):
                     while not await stor.db_pool.updates.has_completed_background_updates():
                         await stor.db_pool.updates.do_next_background_update(1)

tests/util/caches/test_descriptors.py

@@ -661,14 +661,13 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         @descriptors.cachedList("fn", "args1")
         async def list_fn(self, args1, arg2):
-            assert current_context().request == "c1"
+            assert current_context().name == "c1"
             # we want this to behave like an asynchronous function
             await run_on_reactor()
-            assert current_context().request == "c1"
+            assert current_context().name == "c1"
             return self.mock(args1, arg2)

-        with LoggingContext() as c1:
-            c1.request = "c1"
+        with LoggingContext("c1") as c1:
             obj = Cls()
             obj.mock.return_value = {10: "fish", 20: "chips"}
             d1 = obj.list_fn([10, 20], 2)

tests/util/test_logcontext.py

@@ -17,11 +17,10 @@ from .. import unittest
 class LoggingContextTestCase(unittest.TestCase):
     def _check_test_key(self, value):
-        self.assertEquals(current_context().request, value)
+        self.assertEquals(current_context().name, value)

     def test_with_context(self):
-        with LoggingContext() as context_one:
-            context_one.request = "test"
+        with LoggingContext("test"):
             self._check_test_key("test")

     @defer.inlineCallbacks
@@ -30,15 +29,13 @@
         @defer.inlineCallbacks
         def competing_callback():
-            with LoggingContext() as competing_context:
-                competing_context.request = "competing"
+            with LoggingContext("competing"):
                 yield clock.sleep(0)
                 self._check_test_key("competing")

         reactor.callLater(0, competing_callback)

-        with LoggingContext() as context_one:
-            context_one.request = "one"
+        with LoggingContext("one"):
             yield clock.sleep(0)
             self._check_test_key("one")
@@ -47,9 +44,7 @@
         callback_completed = [False]

-        with LoggingContext() as context_one:
-            context_one.request = "one"
-
+        with LoggingContext("one"):
             # fire off function, but don't wait on it.
             d2 = run_in_background(function)
@@ -133,9 +128,7 @@
         sentinel_context = current_context()

-        with LoggingContext() as context_one:
-            context_one.request = "one"
-
+        with LoggingContext("one"):
             d1 = make_deferred_yieldable(blocking_function())
             # make sure that the context was reset by make_deferred_yieldable
             self.assertIs(current_context(), sentinel_context)
@@ -149,9 +142,7 @@
     def test_make_deferred_yieldable_with_chained_deferreds(self):
         sentinel_context = current_context()

-        with LoggingContext() as context_one:
-            context_one.request = "one"
-
+        with LoggingContext("one"):
             d1 = make_deferred_yieldable(_chained_deferred_function())
             # make sure that the context was reset by make_deferred_yieldable
             self.assertIs(current_context(), sentinel_context)
@@ -166,9 +157,7 @@
         """Check that make_deferred_yieldable does the right thing when its
         argument isn't actually a deferred"""

-        with LoggingContext() as context_one:
-            context_one.request = "one"
-
+        with LoggingContext("one"):
             d1 = make_deferred_yieldable("bum")
             self._check_test_key("one")
@@ -177,9 +166,9 @@
             self._check_test_key("one")

     def test_nested_logging_context(self):
-        with LoggingContext(request="foo"):
+        with LoggingContext("foo"):
             nested_context = nested_logging_context(suffix="bar")
-            self.assertEqual(nested_context.request, "foo-bar")
+            self.assertEqual(nested_context.name, "foo-bar")

     @defer.inlineCallbacks
     def test_make_deferred_yieldable_with_await(self):
@@ -193,9 +182,7 @@
         sentinel_context = current_context()

-        with LoggingContext() as context_one:
-            context_one.request = "one"
-
+        with LoggingContext("one"):
             d1 = make_deferred_yieldable(blocking_function())
             # make sure that the context was reset by make_deferred_yieldable
             self.assertIs(current_context(), sentinel_context)