0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 15:01:23 +01:00

Merge branch 'notifier_unify' into notifier_performance

Conflicts:
	synapse/notifier.py
This commit is contained in:
Mark Haines 2015-05-12 16:37:50 +01:00
commit cffe6057fb
28 changed files with 628 additions and 303 deletions

View file

@ -31,6 +31,7 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config #rm $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--generate-config \ --generate-config \
--enable_registration \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \

View file

@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import (
preserve_context_over_fn, preserve_context_over_deferred
)
import simplejson as json import simplejson as json
import logging import logging
@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5): for i in range(5):
try: try:
with PreserveLoggingContext(): protocol = yield preserve_context_over_fn(
protocol = yield endpoint.connect(factory) endpoint.connect, factory
server_response, server_certificate = yield protocol.remote_key )
defer.returnValue((server_response, server_certificate)) server_response, server_certificate = yield preserve_context_over_deferred(
return protocol.remote_key
)
defer.returnValue((server_response, server_certificate))
return
except SynapseKeyClientError as e: except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,)) logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"): if e.status.startswith("4"):

View file

@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util.async import create_observer from synapse.util.async import ObservableDeferred
from OpenSSL import crypto from OpenSSL import crypto
@ -111,6 +111,10 @@ class Keyring(object):
if download is None: if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids) download = self._get_server_verify_key_impl(server_name, key_ids)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.key_downloads[server_name] = download self.key_downloads[server_name] = download
@download.addBoth @download.addBoth
@ -118,7 +122,7 @@ class Keyring(object):
del self.key_downloads[server_name] del self.key_downloads[server_name]
return ret return ret
r = yield create_observer(download) r = yield download.observe()
defer.returnValue(r) defer.returnValue(r)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
import logging import logging
@ -94,7 +96,7 @@ class FederationBase(object):
yield defer.gatherResults( yield defer.gatherResults(
[do(pdu) for pdu in pdus], [do(pdu) for pdu in pdus],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus) defer.returnValue(signed_pdus)

View file

@ -20,7 +20,6 @@ from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -123,29 +122,28 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext(): results = []
results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu) d = self._handle_new_pdu(transaction.origin, pdu)
try: try:
yield d yield d
results.append({}) results.append({})
except FederationError as e: except FederationError as e:
self.send_failure(e, transaction.origin) self.send_failure(e, transaction.origin)
results.append({"error": str(e)}) results.append({"error": str(e)})
except Exception as e: except Exception as e:
results.append({"error": str(e)}) results.append({"error": str(e)})
logger.exception("Failed to handle PDU") logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu( self.received_edu(
transaction.origin, transaction.origin,
edu.edu_type, edu.edu_type,
edu.content edu.content
) )
for failure in getattr(transaction, "pdu_failures", []): for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure) logger.info("Got failure %r", failure)

View file

@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID from synapse.types import UserID
from synapse.util.logcontext import PreserveLoggingContext
import logging import logging
@ -137,10 +139,11 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
# Don't block waiting on waking up all the listeners. with PreserveLoggingContext():
notify_d = self.notifier.on_new_room_event( # Don't block waiting on waking up all the listeners.
event, extra_users=extra_users notify_d = self.notifier.on_new_room_event(
) event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(

View file

@ -15,7 +15,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for(
events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout
auth_user, room_ids, pagin_config, timeout )
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -18,9 +18,11 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, StoreError, AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@ -29,6 +31,8 @@ from synapse.crypto.event_signing import (
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer from twisted.internet import defer
import itertools import itertools
@ -197,9 +201,10 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=extra_users d = self.notifier.on_new_room_event(
) event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -218,10 +223,11 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit): def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id` """ Trigger a backfill request to `dest` for the given `room_id`
""" """
extremities = yield self.store.get_oldest_events_in_room(room_id) if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
pdus = yield self.replication_layer.backfill( pdus = yield self.replication_layer.backfill(
dest, dest,
@ -248,6 +254,138 @@ class FederationHandler(BaseHandler):
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
extremities = yield self.store.get_oldest_events_with_depth_in_room(
room_id
)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
return
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(
extremities.items(),
key=lambda e: -int(e[1])
)
max_depth = sorted_extremeties_tuple[0][1]
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
max_depth, current_depth,
)
return
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
joined_domains = {}
for u, d in joined_users:
try:
dom = UserID.from_string(u).domain
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
except:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains
]
@defer.inlineCallbacks
def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
events = yield self.backfill(
dom, room_id,
limit=100,
extremities=[e for e in extremities.keys()]
)
except SynapseError:
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
logger.warn(
"Failed to backfill from %s because %s",
dom, e,
)
continue
if events:
defer.returnValue(True)
defer.returnValue(False)
success = yield try_backfill(likely_domains)
if success:
defer.returnValue(True)
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
tried_domains = set(likely_domains)
event_ids = list(extremities.keys())
states = yield defer.gatherResults([
self.state_handler.resolve_state_groups([e])
for e in event_ids
])
states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
dom for dom in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
tried_domains.update(likely_domains)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, target_host, event): def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing. """ Sends the invite to the remote server for signing.
@ -431,9 +569,10 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
new_event, extra_users=[joinee] d = self.notifier.on_new_room_event(
) new_event, extra_users=[joinee]
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -512,9 +651,10 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=extra_users d = self.notifier.on_new_room_event(
) event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -594,9 +734,10 @@ class FederationHandler(BaseHandler):
) )
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
d = self.notifier.on_new_room_event( with PreserveLoggingContext():
event, extra_users=[target_user], d = self.notifier.on_new_room_event(
) event, extra_users=[target_user],
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -921,7 +1062,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d] if d in have_events and not have_events[d]
], ],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
if different_events: if different_events:
local_view = dict(auth_events) local_view = dict(auth_events)

View file

@ -20,8 +20,9 @@ from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID, RoomStreamToken
from ._base import BaseHandler from ._base import BaseHandler
@ -89,9 +90,19 @@ class MessageHandler(BaseHandler):
if not pagin_config.from_token: if not pagin_config.from_token:
pagin_config.from_token = ( pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token() yield self.hs.get_event_sources().get_current_token(
direction='b'
)
) )
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
@ -303,7 +314,7 @@ class MessageHandler(BaseHandler):
event.room_id event.room_id
), ),
] ]
) ).addErrback(unwrapFirstError)
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
@ -328,7 +339,7 @@ class MessageHandler(BaseHandler):
yield defer.gatherResults( yield defer.gatherResults(
[handle_room(e) for e in room_list], [handle_room(e) for e in room_list],
consumeErrors=True consumeErrors=True
) ).addErrback(unwrapFirstError)
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,

View file

@ -18,14 +18,15 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
import synapse.metrics import synapse.metrics
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
from collections import OrderedDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -143,7 +144,7 @@ class PresenceHandler(BaseHandler):
self._remote_offline_serials = [] self._remote_offline_serials = []
# map any user to a UserPresenceCache # map any user to a UserPresenceCache
self._user_cachemap = {} self._user_cachemap = OrderedDict() # keep them sorted by serial
self._user_cachemap_latest_serial = 0 self._user_cachemap_latest_serial = 0
metrics.register_callback( metrics.register_callback(
@ -165,6 +166,14 @@ class PresenceHandler(BaseHandler):
else: else:
return UserPresenceCache() return UserPresenceCache()
def _bump_serial(self, user=None):
self._user_cachemap_latest_serial += 1
if user:
# Move to end
cache = self._user_cachemap.pop(user)
self._user_cachemap[user] = cache
def registered_user(self, user): def registered_user(self, user):
return self.store.create_presence(user.localpart) return self.store.create_presence(user.localpart)
@ -278,15 +287,14 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap was_polling = target_user in self._user_cachemap
with PreserveLoggingContext(): if now_online and not was_polling:
if now_online and not was_polling: self.start_polling_presence(target_user, state=state)
self.start_polling_presence(target_user, state=state) elif not now_online and was_polling:
elif not now_online and was_polling: self.stop_polling_presence(target_user)
self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
self.changed_presencelike_data(target_user, state) self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None): def bump_presence_active_time(self, user, now=None):
if now is None: if now is None:
@ -301,7 +309,7 @@ class PresenceHandler(BaseHandler):
def changed_presencelike_data(self, user, state): def changed_presencelike_data(self, user, state):
statuscache = self._get_or_make_usercache(user) statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1 self._bump_serial(user=user)
statuscache.update(state, serial=self._user_cachemap_latest_serial) statuscache.update(state, serial=self._user_cachemap_latest_serial)
return self.push_presence(user, statuscache=statuscache) return self.push_presence(user, statuscache=statuscache)
@ -323,7 +331,7 @@ class PresenceHandler(BaseHandler):
# No actual update but we need to bump the serial anyway for the # No actual update but we need to bump the serial anyway for the
# event source # event source
self._user_cachemap_latest_serial += 1 self._bump_serial()
statuscache.update({}, serial=self._user_cachemap_latest_serial) statuscache.update({}, serial=self._user_cachemap_latest_serial)
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
@ -408,10 +416,10 @@ class PresenceHandler(BaseHandler):
yield self.store.set_presence_list_accepted( yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
with PreserveLoggingContext():
self.start_polling_presence( self.start_polling_presence(
observer_user, target_user=observed_user observer_user, target_user=observed_user
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user): def deny_presence(self, observed_user, observer_user):
@ -430,10 +438,9 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
with PreserveLoggingContext(): self.stop_polling_presence(
self.stop_polling_presence( observer_user, target_user=observed_user
observer_user, target_user=observed_user )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None): def get_presence_list(self, observer_user, accepted=None):
@ -706,7 +713,7 @@ class PresenceHandler(BaseHandler):
statuscache = self._get_or_make_usercache(user) statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1 self._bump_serial(user=user)
statuscache.update(state, serial=self._user_cachemap_latest_serial) statuscache.update(state, serial=self._user_cachemap_latest_serial)
if not observers and not room_ids: if not observers and not room_ids:
@ -766,8 +773,7 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]: if not self._remote_sendmap[user]:
del self._remote_sendmap[user] del self._remote_sendmap[user]
with PreserveLoggingContext(): yield defer.DeferredList(deferreds, consumeErrors=True)
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,
@ -812,10 +818,11 @@ class PresenceHandler(BaseHandler):
def push_update_to_clients(self, observed_user, users_to_push=[], def push_update_to_clients(self, observed_user, users_to_push=[],
room_ids=[], statuscache=None): room_ids=[], statuscache=None):
self.notifier.on_new_user_event( with PreserveLoggingContext():
users_to_push, self.notifier.on_new_user_event(
room_ids, users_to_push,
) room_ids,
)
class PresenceEventSource(object): class PresenceEventSource(object):
@ -866,10 +873,15 @@ class PresenceEventSource(object):
updates = [] updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency. # TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys(): for observed_user in reversed(cachemap.keys()):
cached = cachemap[observed_user] cached = cachemap[observed_user]
if cached.serial <= from_key or cached.serial > max_serial: # Since this is ordered in descending order of serial, we can just
# stop once we've seen enough
if cached.serial <= from_key:
break
if cached.serial > max_serial:
continue continue
if not (yield self.is_visible(observer_user, observed_user)): if not (yield self.is_visible(observer_user, observed_user)):

View file

@ -17,8 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID
from synapse.util import unwrapFirstError
from ._base import BaseHandler from ._base import BaseHandler
@ -154,14 +154,13 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
defer.returnValue(None) defer.returnValue(None)
with PreserveLoggingContext(): (displayname, avatar_url) = yield defer.gatherResults(
(displayname, avatar_url) = yield defer.gatherResults( [
[ self.store.get_profile_displayname(user.localpart),
self.store.get_profile_displayname(user.localpart), self.store.get_profile_avatar_url(user.localpart),
self.store.get_profile_avatar_url(user.localpart), ],
], consumeErrors=True
consumeErrors=True ).addErrback(unwrapFirstError)
)
state["displayname"] = displayname state["displayname"] = displayname
state["avatar_url"] = avatar_url state["avatar_url"] = avatar_url

View file

@ -21,7 +21,7 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -537,7 +537,7 @@ class RoomListHandler(BaseHandler):
for room in chunk for room in chunk
], ],
consumeErrors=True, consumeErrors=True,
) ).addErrback(unwrapFirstError)
for i, room in enumerate(chunk): for i, room in enumerate(chunk):
room["num_joined_members"] = len(results[i]) room["num_joined_members"] = len(results[i])
@ -577,8 +577,8 @@ class RoomEventSource(object):
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))
def get_current_key(self): def get_current_key(self, direction='f'):
return self.store.get_room_events_max_id() return self.store.get_room_events_max_id(direction)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pagination_rows(self, user, config, key): def get_pagination_rows(self, user, config, key):

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID
import logging import logging
@ -216,7 +217,8 @@ class TypingNotificationHandler(BaseHandler):
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
self.notifier.on_new_user_event(rooms=[room_id]) with PreserveLoggingContext():
self.notifier.on_new_user_event(rooms=[room_id])
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
import synapse.metrics import synapse.metrics
@ -61,7 +62,10 @@ class SimpleHttpClient(object):
# A small wrapper around self.agent.request() so we can easily attach # A small wrapper around self.agent.request() so we can easily attach
# counters to it # counters to it
outgoing_requests_counter.inc(method) outgoing_requests_counter.inc(method)
d = self.agent.request(method, *args, **kwargs) d = preserve_context_over_fn(
self.agent.request,
method, *args, **kwargs
)
def _cb(response): def _cb(response):
incoming_responses_counter.inc(method, response.code) incoming_responses_counter.inc(method, response.code)

View file

@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics import synapse.metrics
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, url_bytes, headers_dict) producer = body_callback(method, url_bytes, headers_dict)
try: try:
with PreserveLoggingContext(): request_deferred = preserve_context_over_fn(
request_deferred = self.agent.request( self.agent.request,
destination, destination,
endpoint, endpoint,
method, method,
path_bytes, path_bytes,
param_bytes, param_bytes,
query_bytes, query_bytes,
Headers(headers_dict), Headers(headers_dict),
producer producer
) )
response = yield self.clock.time_bound_deferred( response = yield self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=60, time_out=60,
) )
logger.debug("Got response to %s", method) logger.debug("Got response to %s", method)
break break

View file

@ -17,7 +17,7 @@
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
) )
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
import synapse.metrics import synapse.metrics
from syutil.jsonutil import ( from syutil.jsonutil import (
@ -85,7 +85,9 @@ def request_handler(request_handler):
"Received request: %s %s", "Received request: %s %s",
request.method, request.path request.method, request.path
) )
yield request_handler(self, request) d = request_handler(self, request)
with PreserveLoggingContext():
yield d
code = request.code code = request.code
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code

View file

@ -16,7 +16,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -196,12 +195,11 @@ class Notifier(object):
logger.debug("on_new_room_event listeners %s", user_streams) logger.debug("on_new_room_event listeners %s", user_streams)
with PreserveLoggingContext(): for user_stream in user_streams:
for user_stream in user_streams: try:
try: user_stream.notify(new_token)
user_stream.notify(new_token) except:
except: logger.exception("Failed to notify listener")
logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -223,12 +221,11 @@ class Notifier(object):
for room in rooms: for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set()) user_streams |= self.room_to_user_streams.get(room, set())
with PreserveLoggingContext(): for user_stream in user_streams:
for user_stream in user_streams: try:
try: user_streams.notify(new_token)
user_streams.notify(new_token) except:
except: logger.exception("Failed to notify listener")
logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback, def wait_for_events(self, user, rooms, timeout, callback,

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.util.async import create_observer from synapse.util.async import ObservableDeferred
import os import os
@ -83,13 +83,17 @@ class BaseMediaResource(Resource):
download = self.downloads.get(key) download = self.downloads.get(key)
if download is None: if download is None:
download = self._get_remote_media_impl(server_name, media_id) download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download self.downloads[key] = download
@download.addBoth @download.addBoth
def callback(media_info): def callback(media_info):
del self.downloads[key] del self.downloads[key]
return media_info return media_info
return create_observer(download) return download.observe()
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id): def _get_remote_media_impl(self, server_name, media_id):

View file

@ -18,7 +18,7 @@ from synapse.api.errors import StoreError
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
import synapse.metrics import synapse.metrics
@ -420,10 +420,11 @@ class SQLBaseStore(object):
self._txn_perf_counters.update(desc, start, end) self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc) sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext(): result = yield preserve_context_over_fn(
result = yield self._db_pool.runWithConnection( self._db_pool.runWithConnection,
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
for after_callback, after_args in after_callbacks: for after_callback, after_args in after_callbacks:
after_callback(*after_args) after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)

View file

@ -79,6 +79,28 @@ class EventFederationStore(SQLBaseStore):
room_id, room_id,
) )
def get_oldest_events_with_depth_in_room(self, room_id):
return self.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
)
def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
sql = (
"SELECT b.event_id, MAX(e.depth) FROM events as e"
" INNER JOIN event_edges as g"
" ON g.event_id = e.event_id AND g.room_id = e.room_id"
" INNER JOIN event_backward_extremities as b"
" ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
" WHERE b.room_id = ? AND g.is_state is ?"
" GROUP BY b.event_id"
)
txn.execute(sql, (room_id, False,))
return dict(txn.fetchall())
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
txn, txn,
@ -247,11 +269,13 @@ class EventFederationStore(SQLBaseStore):
do_insert = depth < min_depth if min_depth else True do_insert = depth < min_depth if min_depth else True
if do_insert: if do_insert:
self._simple_insert_txn( self._simple_upsert_txn(
txn, txn,
table="room_depth", table="room_depth",
values={ keyvalues={
"room_id": room_id, "room_id": room_id,
},
values={
"min_depth": depth, "min_depth": depth,
}, },
) )
@ -306,31 +330,27 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (event_id, room_id)) txn.execute(query, (event_id, room_id))
# Insert all the prev_events as a backwards thing, they'll get query = (
# deleted in a second if they're incorrect anyway. "INSERT INTO event_backward_extremities (event_id, room_id)"
self._simple_insert_many_txn( " SELECT ?, ? WHERE NOT EXISTS ("
txn, " SELECT 1 FROM event_backward_extremities"
table="event_backward_extremities", " WHERE event_id = ? AND room_id = ?"
values=[ " )"
{ " AND NOT EXISTS ("
"event_id": e_id, " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
"room_id": room_id, " )"
}
for e_id, _ in prev_events
],
) )
# Also delete from the backwards extremities table all ones that txn.executemany(query, [
# reference events that we have already seen (e_id, room_id, e_id, room_id, e_id, room_id, )
for e_id, _ in prev_events
])
query = ( query = (
"DELETE FROM event_backward_extremities WHERE EXISTS (" "DELETE FROM event_backward_extremities"
"SELECT 1 FROM events " " WHERE event_id = ? AND room_id = ?"
"WHERE "
"event_backward_extremities.event_id = events.event_id "
"AND not events.outlier "
")"
) )
txn.execute(query) txn.execute(query, (event_id, room_id))
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, room_id

View file

@ -135,19 +135,17 @@ class EventsStore(SQLBaseStore):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
if not outlier: if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn( self._update_min_depth_for_room_txn(
txn, txn,
event.room_id, event.room_id,
event.depth event.depth
) )
have_persisted = self._simple_select_one_onecol_txn( have_persisted = self._simple_select_one_txn(
txn, txn,
table="event_json", table="events",
keyvalues={"event_id": event.event_id}, keyvalues={"event_id": event.event_id},
retcol="event_id", retcols=["event_id", "outlier"],
allow_none=True, allow_none=True,
) )
@ -162,7 +160,9 @@ class EventsStore(SQLBaseStore):
# if we are persisting an event that we had persisted as an outlier, # if we are persisting an event that we had persisted as an outlier,
# but is no longer one. # but is no longer one.
if have_persisted: if have_persisted:
if not outlier: if not outlier and have_persisted["outlier"]:
self._store_state_groups_txn(txn, event, context)
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?" " WHERE event_id = ?"
@ -182,6 +182,9 @@ class EventsStore(SQLBaseStore):
) )
return return
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._handle_prev_events( self._handle_prev_events(
txn, txn,
outlier=outlier, outlier=outlier,

View file

@ -37,11 +37,9 @@ from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from collections import namedtuple
import logging import logging
@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological" _TOPOLOGICAL_TOKEN = "topological"
class _StreamToken(namedtuple("_StreamToken", "topological stream")): def lower_bound(token):
"""Tokens are positions between events. The token "s1" comes after event 1. if token.topological is None:
return "(%d < %s)" % (token.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, "stream_ordering",
)
s0 s1
| |
[0] V [1] V [2]
Tokens can either be a point in the live event stream or a cursor going def upper_bound(token):
through historic events. if token.topological is None:
return "(%d >= %s)" % (token.stream, "stream_ordering")
When traversing the live event stream events are ordered by when they else:
arrived at the homeserver. return "(%d > %s OR (%d = %s AND %d >= %s))" % (
token.topological, "topological_ordering",
When traversing historic events the events are ordered by their depth in token.topological, "topological_ordering",
the event graph "topological_ordering" and then by when they arrived at the token.stream, "stream_ordering",
homeserver "stream_ordering". )
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@classmethod
def parse(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
def lower_bound(self):
if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)
def upper_bound(self):
if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering")
else:
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering. # From and to keys should be integers from ordering.
from_id = _StreamToken.parse_stream_token(from_key) from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key) to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key: if from_key == to_key:
defer.returnValue(([], to_key)) defer.returnValue(([], to_key))
@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering. # From and to keys should be integers from ordering.
from_id = _StreamToken.parse_stream_token(from_key) from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key) to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key: if from_key == to_key:
return defer.succeed(([], to_key)) return defer.succeed(([], to_key))
@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
args = [False, room_id] args = [False, room_id]
if direction == 'b': if direction == 'b':
order = "DESC" order = "DESC"
bounds = _StreamToken.parse(from_key).upper_bound() bounds = upper_bound(RoomStreamToken.parse(from_key))
if to_key: if to_key:
bounds = "%s AND %s" % ( bounds = "%s AND %s" % (
bounds, _StreamToken.parse(to_key).lower_bound() bounds, lower_bound(RoomStreamToken.parse(to_key))
) )
else: else:
order = "ASC" order = "ASC"
bounds = _StreamToken.parse(from_key).lower_bound() bounds = lower_bound(RoomStreamToken.parse(from_key))
if to_key: if to_key:
bounds = "%s AND %s" % ( bounds = "%s AND %s" % (
bounds, _StreamToken.parse(to_key).upper_bound() bounds, upper_bound(RoomStreamToken.parse(to_key))
) )
if int(limit) > 0: if int(limit) > 0:
@ -333,7 +281,7 @@ class StreamStore(SQLBaseStore):
# when we are going backwards so we subtract one from the # when we are going backwards so we subtract one from the
# stream part. # stream part.
toke -= 1 toke -= 1
next_token = str(_StreamToken(topo, toke)) next_token = str(RoomStreamToken(topo, toke))
else: else:
# TODO (erikj): We should work out what to do here instead. # TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key next_token = to_key if to_key else from_key
@ -354,7 +302,7 @@ class StreamStore(SQLBaseStore):
with_feedback=False, from_token=None): with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback # TODO (erikj): Handle compressed feedback
end_token = _StreamToken.parse_stream_token(end_token) end_token = RoomStreamToken.parse_stream_token(end_token)
if from_token is None: if from_token is None:
sql = ( sql = (
@ -365,7 +313,7 @@ class StreamStore(SQLBaseStore):
" LIMIT ?" " LIMIT ?"
) )
else: else:
from_token = _StreamToken.parse_stream_token(from_token) from_token = RoomStreamToken.parse_stream_token(from_token)
sql = ( sql = (
"SELECT stream_ordering, topological_ordering, event_id" "SELECT stream_ordering, topological_ordering, event_id"
" FROM events" " FROM events"
@ -395,7 +343,7 @@ class StreamStore(SQLBaseStore):
# stream part. # stream part.
topo = rows[0]["topological_ordering"] topo = rows[0]["topological_ordering"]
toke = rows[0]["stream_ordering"] - 1 toke = rows[0]["stream_ordering"] - 1
start_token = str(_StreamToken(topo, toke)) start_token = str(RoomStreamToken(topo, toke))
token = (start_token, str(end_token)) token = (start_token, str(end_token))
else: else:
@ -416,9 +364,25 @@ class StreamStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_max_id(self): def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token(self) token = yield self._stream_id_gen.get_max_token(self)
defer.returnValue("s%d" % (token,)) if direction != 'b':
defer.returnValue("s%d" % (token,))
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn
)
defer.returnValue("t%d-%d" % (topo, token))
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
" WHERE outlier = ?",
(False,)
)
rows = txn.fetchall()
return rows[0][0] if rows else 0
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_min_token(self): def _get_min_token(self):
@ -439,5 +403,5 @@ class StreamStore(SQLBaseStore):
stream = row["stream_ordering"] stream = row["stream_ordering"]
topo = event.depth topo = event.depth
internal = event.internal_metadata internal = event.internal_metadata
internal.before = str(_StreamToken(topo, stream - 1)) internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(_StreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))

View file

@ -31,7 +31,7 @@ class NullSource(object):
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key)) return defer.succeed(([], from_key))
def get_current_key(self): def get_current_key(self, direction='f'):
return defer.succeed(0) return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
@ -52,10 +52,10 @@ class EventSources(object):
} }
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_token(self): def get_current_token(self, direction='f'):
token = StreamToken( token = StreamToken(
room_key=( room_key=(
yield self.sources["room"].get_current_key() yield self.sources["room"].get_current_key(direction)
), ),
presence_key=( presence_key=(
yield self.sources["presence"].get_current_key() yield self.sources["presence"].get_current_key()

View file

@ -121,4 +121,56 @@ class StreamToken(
return StreamToken(**d) return StreamToken(**d)
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
| |
[0] V [1] V [2]
Tokens can either be a point in the live event stream or a cursor going
through historic events.
When traversing the live event stream events are ordered by when they
arrived at the homeserver.
When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@classmethod
def parse(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from twisted.internet import defer, reactor, task from twisted.internet import defer, reactor, task
@ -23,6 +23,12 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
return failure.value.subFailure
class Clock(object): class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.
@ -50,9 +56,12 @@ class Clock(object):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(): def wrapped_callback():
LoggingContext.thread_local.current_context = current_context with PreserveLoggingContext():
callback() LoggingContext.thread_local.current_context = current_context
return reactor.callLater(delay, wrapped_callback) callback()
with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback)
def cancel_call_later(self, timer): def cancel_call_later(self, timer):
timer.cancel() timer.cancel()

View file

@ -16,15 +16,13 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext from .logcontext import preserve_context_over_deferred
@defer.inlineCallbacks
def sleep(seconds): def sleep(seconds):
d = defer.Deferred() d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds) reactor.callLater(seconds, d.callback, seconds)
with PreserveLoggingContext(): return preserve_context_over_deferred(d)
yield d
def run_on_reactor(): def run_on_reactor():
@ -34,20 +32,56 @@ def run_on_reactor():
return sleep(0) return sleep(0)
def create_observer(deferred): class ObservableDeferred(object):
"""Creates a deferred that observes the result or failure of the given """Wraps a deferred object so that we can add observer deferreds. These
deferred *without* affecting the given deferred. observer deferreds do not affect the callback chain of the original
deferred.
If consumeErrors is true errors will be captured from the origin deferred.
""" """
d = defer.Deferred()
def callback(r): __slots__ = ["_deferred", "_observers", "_result"]
d.callback(r)
return r
def errback(f): def __init__(self, deferred, consumeErrors=False):
d.errback(f) object.__setattr__(self, "_deferred", deferred)
return f object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", [])
deferred.addCallbacks(callback, errback) def callback(r):
self._result = (True, r)
while self._observers:
try:
self._observers.pop().callback(r)
except:
pass
return r
return d def errback(f):
self._result = (False, f)
while self._observers:
try:
self._observers.pop().errback(f)
except:
pass
if consumeErrors:
return None
else:
return f
deferred.addCallbacks(callback, errback)
def observe(self):
if not self._result:
d = defer.Deferred()
self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
def __getattr__(self, name):
return getattr(self._deferred, name)
def __setattr__(self, name, value):
setattr(self._deferred, name, value)

View file

@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred,
)
from synapse.util import unwrapFirstError
import logging import logging
@ -93,7 +97,6 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -101,24 +104,28 @@ class Signal(object):
Returns a Deferred that will complete when all the observers have Returns a Deferred that will complete when all the observers have
completed.""" completed."""
with PreserveLoggingContext():
deferreds = []
for observer in self.observers:
d = defer.maybeDeferred(observer, *args, **kwargs)
def eb(failure): def do(observer):
logger.warning( def eb(failure):
"%s signal observer %s failed: %r", logger.warning(
self.name, observer, failure, "%s signal observer %s failed: %r",
exc_info=( self.name, observer, failure,
failure.type, exc_info=(
failure.value, failure.type,
failure.getTracebackObject())) failure.value,
if not self.suppress_failures: failure.getTracebackObject()))
failure.raiseException() if not self.suppress_failures:
deferreds.append(d.addErrback(eb)) return failure
results = [] return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
for deferred in deferreds:
result = yield deferred with PreserveLoggingContext():
results.append(result) deferreds = [
defer.returnValue(results) do(observer)
for observer in self.observers
]
d = defer.gatherResults(deferreds, consumeErrors=True)
d.addErrback(unwrapFirstError)
return preserve_context_over_deferred(d)

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import threading import threading
import logging import logging
@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
"""Restores the current logging context""" """Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context LoggingContext.thread_local.current_context = self.current_context
if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None:
logger.warn(
"Restoring dead context: %s",
self.current_context,
)
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
If the result is a deferred, call preserve_context_over_deferred before
returning it.
"""
with PreserveLoggingContext():
res = fn(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(res)
else:
return res
def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
d = defer.Deferred()
current_context = LoggingContext.current_context()
def cb(res):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.callback(res)
return res
def eb(failure):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.errback(failure)
return res
if deferred.called:
return deferred
deferred.addCallbacks(cb, eb)
return d