0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-02 11:34:02 +01:00

Merge pull request #144 from matrix-org/erikj/logging_context

Preserving logging contexts
This commit is contained in:
Mark Haines 2015-05-12 15:23:50 +01:00
commit a6fb2aa2a5
21 changed files with 236 additions and 153 deletions

View file

@ -31,6 +31,7 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config #rm $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--generate-config \ --generate-config \
--enable_registration \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \

View file

@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import (
preserve_context_over_fn, preserve_context_over_deferred
)
import simplejson as json import simplejson as json
import logging import logging
@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5): for i in range(5):
try: try:
with PreserveLoggingContext(): protocol = yield preserve_context_over_fn(
protocol = yield endpoint.connect(factory) endpoint.connect, factory
server_response, server_certificate = yield protocol.remote_key )
defer.returnValue((server_response, server_certificate)) server_response, server_certificate = yield preserve_context_over_deferred(
return protocol.remote_key
)
defer.returnValue((server_response, server_certificate))
return
except SynapseKeyClientError as e: except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,)) logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"): if e.status.startswith("4"):

View file

@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
import logging import logging
@ -94,7 +96,7 @@ class FederationBase(object):
yield defer.gatherResults( yield defer.gatherResults(
[do(pdu) for pdu in pdus], [do(pdu) for pdu in pdus],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus) defer.returnValue(signed_pdus)

View file

@ -20,7 +20,6 @@ from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -123,29 +122,28 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext(): results = []
results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu) d = self._handle_new_pdu(transaction.origin, pdu)
try: try:
yield d yield d
results.append({}) results.append({})
except FederationError as e: except FederationError as e:
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)
results.append({"error": str(e)}) results.append({"error": str(e)})
except Exception as e: except Exception as e:
results.append({"error": str(e)}) results.append({"error": str(e)})
logger.exception("Failed to handle PDU") logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu( self.received_edu(
transaction.origin, transaction.origin,
edu.edu_type, edu.edu_type,
edu.content edu.content
) )
for failure in getattr(transaction, "pdu_failures", []): for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure) logger.info("Got failure %r", failure)

View file

@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID from synapse.types import UserID
from synapse.util.logcontext import PreserveLoggingContext
import logging import logging
@ -137,10 +139,11 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
# Don't block waiting on waking up all the listeners. with PreserveLoggingContext():
notify_d = self.notifier.on_new_room_event( # Don't block waiting on waking up all the listeners.
event, extra_users=extra_users notify_d = self.notifier.on_new_room_event(
) event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(

View file

@ -15,7 +15,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for(
events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout
auth_user, room_ids, pagin_config, timeout )
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -21,6 +21,8 @@ from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError, AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@ -199,9 +201,10 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=extra_users d = self.notifier.on_new_room_event(
) event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -566,9 +569,10 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
new_event, extra_users=[joinee] d = self.notifier.on_new_room_event(
) new_event, extra_users=[joinee]
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -647,9 +651,10 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=extra_users d = self.notifier.on_new_room_event(
) event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -729,9 +734,10 @@ class FederationHandler(BaseHandler):
) )
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=[target_user], d = self.notifier.on_new_room_event(
) event, extra_users=[target_user],
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -1056,7 +1062,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d] if d in have_events and not have_events[d]
], ],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
if different_events: if different_events:
local_view = dict(auth_events) local_view = dict(auth_events)

View file

@ -20,6 +20,7 @@ from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken from synapse.types import UserID, RoomStreamToken
@ -313,7 +314,7 @@ class MessageHandler(BaseHandler):
event.room_id event.room_id
), ),
] ]
) ).addErrback(unwrapFirstError)
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
@ -338,7 +339,7 @@ class MessageHandler(BaseHandler):
yield defer.gatherResults( yield defer.gatherResults(
[handle_room(e) for e in room_list], [handle_room(e) for e in room_list],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,

View file

@ -18,8 +18,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
import synapse.metrics import synapse.metrics
@ -278,15 +278,14 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap was_polling = target_user in self._user_cachemap
with PreserveLoggingContext(): if now_online and not was_polling:
if now_online and not was_polling: self.start_polling_presence(target_user, state=state)
self.start_polling_presence(target_user, state=state) elif not now_online and was_polling:
elif not now_online and was_polling: self.stop_polling_presence(target_user)
self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
self.changed_presencelike_data(target_user, state) self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None): def bump_presence_active_time(self, user, now=None):
if now is None: if now is None:
@ -408,10 +407,10 @@ class PresenceHandler(BaseHandler):
yield self.store.set_presence_list_accepted( yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
with PreserveLoggingContext():
self.start_polling_presence( self.start_polling_presence(
observer_user, target_user=observed_user observer_user, target_user=observed_user
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user): def deny_presence(self, observed_user, observer_user):
@ -430,10 +429,9 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
with PreserveLoggingContext(): self.stop_polling_presence(
self.stop_polling_presence( observer_user, target_user=observed_user
observer_user, target_user=observed_user )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None): def get_presence_list(self, observer_user, accepted=None):
@ -766,8 +764,7 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]: if not self._remote_sendmap[user]:
del self._remote_sendmap[user] del self._remote_sendmap[user]
with PreserveLoggingContext(): yield defer.DeferredList(deferreds, consumeErrors=True)
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,
@ -812,10 +809,11 @@ class PresenceHandler(BaseHandler):
def push_update_to_clients(self, observed_user, users_to_push=[], def push_update_to_clients(self, observed_user, users_to_push=[],
room_ids=[], statuscache=None): room_ids=[], statuscache=None):
self.notifier.on_new_user_event( with PreserveLoggingContext():
users_to_push, self.notifier.on_new_user_event(
room_ids, users_to_push,
) room_ids,
)
class PresenceEventSource(object): class PresenceEventSource(object):

View file

@ -17,8 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID
from synapse.util import unwrapFirstError
from ._base import BaseHandler from ._base import BaseHandler
@ -154,14 +154,13 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
defer.returnValue(None) defer.returnValue(None)
with PreserveLoggingContext(): (displayname, avatar_url) = yield defer.gatherResults(
(displayname, avatar_url) = yield defer.gatherResults( [
[ self.store.get_profile_displayname(user.localpart),
self.store.get_profile_displayname(user.localpart), self.store.get_profile_avatar_url(user.localpart),
self.store.get_profile_avatar_url(user.localpart), ],
], consumeErrors=True
consumeErrors=True ).addErrback(unwrapFirstError)
)
state["displayname"] = displayname state["displayname"] = displayname
state["avatar_url"] = avatar_url state["avatar_url"] = avatar_url

View file

@ -21,7 +21,7 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -537,7 +537,7 @@ class RoomListHandler(BaseHandler):
for room in chunk for room in chunk
], ],
consumeErrors=True, consumeErrors=True,
) ).addErrback(unwrapFirstError)
for i, room in enumerate(chunk): for i, room in enumerate(chunk):
room["num_joined_members"] = len(results[i]) room["num_joined_members"] = len(results[i])

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID
import logging import logging
@ -216,7 +217,8 @@ class TypingNotificationHandler(BaseHandler):
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
self.notifier.on_new_user_event(rooms=[room_id]) with PreserveLoggingContext():
self.notifier.on_new_user_event(rooms=[room_id])
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
import synapse.metrics import synapse.metrics
@ -61,7 +62,10 @@ class SimpleHttpClient(object):
# A small wrapper around self.agent.request() so we can easily attach # A small wrapper around self.agent.request() so we can easily attach
# counters to it # counters to it
outgoing_requests_counter.inc(method) outgoing_requests_counter.inc(method)
d = self.agent.request(method, *args, **kwargs) d = preserve_context_over_fn(
self.agent.request,
method, *args, **kwargs
)
def _cb(response): def _cb(response):
incoming_responses_counter.inc(method, response.code) incoming_responses_counter.inc(method, response.code)

View file

@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics import synapse.metrics
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, url_bytes, headers_dict) producer = body_callback(method, url_bytes, headers_dict)
try: try:
with PreserveLoggingContext(): request_deferred = preserve_context_over_fn(
request_deferred = self.agent.request( self.agent.request,
destination, destination,
endpoint, endpoint,
method, method,
path_bytes, path_bytes,
param_bytes, param_bytes,
query_bytes, query_bytes,
Headers(headers_dict), Headers(headers_dict),
producer producer
) )
response = yield self.clock.time_bound_deferred( response = yield self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=60, time_out=60,
) )
logger.debug("Got response to %s", method) logger.debug("Got response to %s", method)
break break

View file

@ -17,7 +17,7 @@
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
) )
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
import synapse.metrics import synapse.metrics
from syutil.jsonutil import ( from syutil.jsonutil import (
@ -85,7 +85,9 @@ def request_handler(request_handler):
"Received request: %s %s", "Received request: %s %s",
request.method, request.path request.method, request.path
) )
yield request_handler(self, request) d = request_handler(self, request)
with PreserveLoggingContext():
yield d
code = request.code code = request.code
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code

View file

@ -16,7 +16,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -223,11 +222,10 @@ class Notifier(object):
def eb(failure): def eb(failure):
logger.exception("Failed to notify listener", failure) logger.exception("Failed to notify listener", failure)
with PreserveLoggingContext(): yield defer.DeferredList(
yield defer.DeferredList( [notify(l).addErrback(eb) for l in listeners],
[notify(l).addErrback(eb) for l in listeners], consumeErrors=True,
consumeErrors=True, )
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -298,11 +296,10 @@ class Notifier(object):
failure.getTracebackObject()) failure.getTracebackObject())
) )
with PreserveLoggingContext(): yield defer.DeferredList(
yield defer.DeferredList( [notify(l).addErrback(eb) for l in listeners],
[notify(l).addErrback(eb) for l in listeners], consumeErrors=True,
consumeErrors=True, )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, filter, timeout, callback): def wait_for_events(self, user, rooms, filter, timeout, callback):

View file

@ -18,7 +18,7 @@ from synapse.api.errors import StoreError
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
import synapse.metrics import synapse.metrics
@ -420,10 +420,11 @@ class SQLBaseStore(object):
self._txn_perf_counters.update(desc, start, end) self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc) sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext(): result = yield preserve_context_over_fn(
result = yield self._db_pool.runWithConnection( self._db_pool.runWithConnection,
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
for after_callback, after_args in after_callbacks: for after_callback, after_args in after_callbacks:
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from twisted.internet import defer, reactor, task from twisted.internet import defer, reactor, task
@ -23,6 +23,12 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
return failure.value.subFailure
class Clock(object): class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.
@ -50,9 +56,12 @@ class Clock(object):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(): def wrapped_callback():
LoggingContext.thread_local.current_context = current_context with PreserveLoggingContext():
callback() LoggingContext.thread_local.current_context = current_context
return reactor.callLater(delay, wrapped_callback) callback()
with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback)
def cancel_call_later(self, timer): def cancel_call_later(self, timer):
timer.cancel() timer.cancel()

View file

@ -16,15 +16,13 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext from .logcontext import preserve_context_over_deferred
@defer.inlineCallbacks
def sleep(seconds): def sleep(seconds):
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds) reactor.callLater(seconds, d.callback, seconds)
with PreserveLoggingContext(): return preserve_context_over_deferred(d)
yield d
def run_on_reactor(): def run_on_reactor():

View file

@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred,
)
from synapse.util import unwrapFirstError
import logging import logging
@ -93,7 +97,6 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -101,24 +104,28 @@ class Signal(object):
Returns a Deferred that will complete when all the observers have Returns a Deferred that will complete when all the observers have
completed.""" completed."""
with PreserveLoggingContext():
deferreds = []
for observer in self.observers:
d = defer.maybeDeferred(observer, *args, **kwargs)
def eb(failure): def do(observer):
logger.warning( def eb(failure):
"%s signal observer %s failed: %r", logger.warning(
self.name, observer, failure, "%s signal observer %s failed: %r",
exc_info=( self.name, observer, failure,
failure.type, exc_info=(
failure.value, failure.type,
failure.getTracebackObject())) failure.value,
if not self.suppress_failures: failure.getTracebackObject()))
failure.raiseException() if not self.suppress_failures:
deferreds.append(d.addErrback(eb)) return failure
results = [] return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
for deferred in deferreds:
result = yield deferred with PreserveLoggingContext():
results.append(result) deferreds = [
defer.returnValue(results) do(observer)
for observer in self.observers
]
d = defer.gatherResults(deferreds, consumeErrors=True)
d.addErrback(unwrapFirstError)
return preserve_context_over_deferred(d)

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import threading import threading
import logging import logging
@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
"""Restores the current logging context""" """Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context LoggingContext.thread_local.current_context = self.current_context
if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None:
logger.warn(
"Restoring dead context: %s",
self.current_context,
)
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
If the result is a deferred, call preserve_context_over_deferred before
returning it.
"""
with PreserveLoggingContext():
res = fn(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(res)
else:
return res
def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
d = defer.Deferred()
current_context = LoggingContext.current_context()
def cb(res):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.callback(res)
return res
def eb(failure):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.errback(failure)
return res
if deferred.called:
return deferred
deferred.addCallbacks(cb, eb)
return d