0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 17:13:50 +01:00

Merge branch 'develop' into markjh/direct_to_device

This commit is contained in:
Mark Haines 2016-08-25 18:34:46 +01:00
commit ab34fdecb7
22 changed files with 238 additions and 103 deletions

View file

@ -1,3 +1,50 @@
Changes in synapse v0.17.1 (2016-08-24)
=======================================
Changes:
* Delete old received_transactions rows (PR #1038)
* Pass through user-supplied content in /join/$room_id (PR #1039)
Bug fixes:
* Fix bug with backfill (PR #1040)
Changes in synapse v0.17.1-rc1 (2016-08-22)
===========================================
Features:
* Add notification API (PR #1028)
Changes:
* Don't print stack traces when failing to get remote keys (PR #996)
* Various federation /event/ perf improvements (PR #998)
* Only process one local membership event per room at a time (PR #1005)
* Move default display name push rule (PR #1011, #1023)
* Fix up preview URL API. Add tests. (PR #1015)
* Set ``Content-Security-Policy`` on media repo (PR #1021)
* Make notify_interested_services faster (PR #1022)
* Add usage stats to prometheus monitoring (PR #1037)
Bug fixes:
* Fix token login (PR #993)
* Fix CAS login (PR #994, #995)
* Fix /sync to not clobber status_msg (PR #997)
* Fix redacted state events to include prev_content (PR #1003)
* Fix some bugs in the auth/ldap handler (PR #1007)
* Fix backfill request to limit URI length, so that remotes don't reject the
requests due to path length limits (PR #1012)
* Fix AS push code to not send duplicate events (PR #1025)
Changes in synapse v0.17.0 (2016-08-08) Changes in synapse v0.17.0 (2016-08-08)
======================================= =======================================

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.17.0" __version__ = "0.17.1"

View file

@ -88,6 +88,8 @@ class ApplicationService(object):
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
# .protocols is a publicly visible field
if protocols: if protocols:
self.protocols = set(protocols) self.protocols = set(protocols)
else: else:

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyEntityKind from synapse.types import ThirdPartyEntityKind
import logging import logging
@ -25,6 +26,9 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HOUR_IN_MS = 60 * 60 * 1000
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
@ -56,6 +60,8 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs) super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user(self, service, user_id): def query_user(self, service, user_id):
uri = service.url + ("/users/%s" % urllib.quote(user_id)) uri = service.url + ("/users/%s" % urllib.quote(user_id))
@ -97,10 +103,10 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) uri = "%s/thirdparty/user/%s" % (service.url, urllib.quote(protocol))
required_field = "userid" required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) uri = "%s/thirdparty/location/%s" % (service.url, urllib.quote(protocol))
required_field = "alias" required_field = "alias"
else: else:
raise ValueError( raise ValueError(
@ -131,6 +137,22 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex) logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([]) defer.returnValue([])
def get_3pe_protocol(self, service, protocol):
@defer.inlineCallbacks
def _get():
uri = "%s/thirdparty/protocol/%s" % (service.url, urllib.quote(protocol))
try:
defer.returnValue((yield self.get_json(uri, {})))
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
defer.returnValue({})
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
self.protocol_meta_cache.set(key, _get())
)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events) events = self._serialize(events)

View file

@ -150,12 +150,12 @@ class _TransactionController(object):
if service_is_up: if service_is_up:
sent = yield txn.send(self.as_api) sent = yield txn.send(self.as_api)
if sent: if sent:
txn.complete(self.store) yield txn.complete(self.store)
else: else:
self._start_recoverer(service) preserve_fn(self._start_recoverer)(service)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
self._start_recoverer(service) preserve_fn(self._start_recoverer)(service)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_recovered(self, recoverer): def on_recovered(self, recoverer):

View file

@ -308,15 +308,15 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults( res = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.store.get_server_verify_keys( preserve_fn(self.store.get_server_verify_keys)(
server_name, key_ids server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name) ).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(dict(res)) defer.returnValue(dict(res))
@ -337,13 +337,13 @@ class Keyring(object):
) )
defer.returnValue({}) defer.returnValue({})
results = yield defer.gatherResults( results = yield preserve_context_over_deferred(defer.gatherResults(
[ [
get_key(p_name, p_keys) preserve_fn(get_key)(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items() for p_name, p_keys in self.perspective_servers.items()
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
union_of_keys = {} union_of_keys = {}
for result in results: for result in results:
@ -383,13 +383,13 @@ class Keyring(object):
defer.returnValue(keys) defer.returnValue(keys)
results = yield defer.gatherResults( results = yield preserve_context_over_deferred(defer.gatherResults(
[ [
get_key(server_name, key_ids) preserve_fn(get_key)(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
merged = {} merged = {}
for result in results: for result in results:
@ -466,9 +466,9 @@ class Keyring(object):
for server_name, response_keys in processed_response.items(): for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys) keys.setdefault(server_name, {}).update(response_keys)
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.store_keys( preserve_fn(self.store_keys)(
server_name=server_name, server_name=server_name,
from_server=perspective_name, from_server=perspective_name,
verify_keys=response_keys, verify_keys=response_keys,
@ -476,7 +476,7 @@ class Keyring(object):
for server_name, response_keys in keys.items() for server_name, response_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@ -524,7 +524,7 @@ class Keyring(object):
keys.update(response_keys) keys.update(response_keys)
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self.store_keys)( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
@ -534,7 +534,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items() for key_server_name, verify_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@ -600,7 +600,7 @@ class Keyring(object):
response_keys.update(verify_keys) response_keys.update(verify_keys)
response_keys.update(old_verify_keys) response_keys.update(old_verify_keys)
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self.store.store_server_keys_json)( preserve_fn(self.store.store_server_keys_json)(
server_name=server_name, server_name=server_name,
@ -613,7 +613,7 @@ class Keyring(object):
for key_id in updated_key_ids for key_id in updated_key_ids
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
results[server_name] = response_keys results[server_name] = response_keys
@ -702,7 +702,7 @@ class Keyring(object):
A deferred that completes when the keys are stored. A deferred that completes when the keys are stored.
""" """
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults( yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self.store.store_server_verify_key)( preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
@ -710,4 +710,4 @@ class Keyring(object):
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)

View file

@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging import logging
@ -102,10 +103,10 @@ class FederationBase(object):
warn, pdu warn, pdu
) )
valid_pdus = yield defer.gatherResults( valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
deferreds, deferreds,
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
if include_none: if include_none:
defer.returnValue(valid_pdus) defer.returnValue(valid_pdus)
@ -129,7 +130,7 @@ class FederationBase(object):
for pdu in pdus for pdu in pdus
] ]
deferreds = self.keyring.verify_json_objects_for_server([ deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
(p.origin, p.get_pdu_json()) (p.origin, p.get_pdu_json())
for p in redacted_pdus for p in redacted_pdus
]) ])

View file

@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -225,10 +226,10 @@ class FederationClient(FederationBase):
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults( pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
self._check_sigs_and_hashes(pdus), self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
defer.returnValue(pdus) defer.returnValue(pdus)
@ -457,14 +458,16 @@ class FederationClient(FederationBase):
batch = set(missing_events[i:i + batch_size]) batch = set(missing_events[i:i + batch_size])
deferreds = [ deferreds = [
self.get_pdu( preserve_fn(self.get_pdu)(
destinations=random_server_list(), destinations=random_server_list(),
event_id=e_id, event_id=e_id,
) )
for e_id in batch for e_id in batch
] ]
res = yield defer.DeferredList(deferreds, consumeErrors=True) res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res: for success, result in res:
if success: if success:
signed_events.append(result) signed_events.append(result)
@ -853,14 +856,16 @@ class FederationClient(FederationBase):
return srvs return srvs
deferreds = [ deferreds = [
self.get_pdu( preserve_fn(self.get_pdu)(
destinations=random_server_list(), destinations=random_server_list(),
event_id=e_id, event_id=e_id,
) )
for e_id, depth in ordered_missing[:limit - len(signed_events)] for e_id, depth in ordered_missing[:limit - len(signed_events)]
] ]
res = yield defer.DeferredList(deferreds, consumeErrors=True) res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for (result, val), (e_id, _) in zip(res, ordered_missing): for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val: if result and val:
signed_events.append(val) signed_events.append(val)

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging import logging
@ -163,10 +163,10 @@ class ApplicationServicesHandler(object):
def query_3pe(self, kind, protocol, fields): def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol) services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([ results = yield preserve_context_over_deferred(defer.DeferredList([
self.appservice_api.query_3pe(service, kind, protocol, fields) preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
for service in services for service in services
], consumeErrors=True) ], consumeErrors=True))
ret = [] ret = []
for (success, result) in results: for (success, result) in results:
@ -175,6 +175,16 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def get_3pe_protocols(self):
services = yield self.store.get_app_services()
protocols = {}
for s in services:
for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
defer.returnValue(protocols)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_event(self, event): def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.

View file

@ -26,7 +26,9 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
)
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@ -361,9 +363,9 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch missing_auth - failed_to_fetch
) )
results = yield defer.gatherResults( results = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.replication_layer.get_pdu( preserve_fn(self.replication_layer.get_pdu)(
[dest], [dest],
event_id, event_id,
outlier=True, outlier=True,
@ -372,7 +374,7 @@ class FederationHandler(BaseHandler):
for event_id in missing_auth - failed_to_fetch for event_id in missing_auth - failed_to_fetch
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a}) auth_events.update({a.event_id: a for a in results if a})
required_auth.update( required_auth.update(
a_id for event in results for a_id, _ in event.auth_events if event a_id for event in results for a_id, _ in event.auth_events if event
@ -552,10 +554,10 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
states = yield defer.gatherResults([ states = yield preserve_context_over_deferred(defer.gatherResults([
self.state_handler.resolve_state_groups(room_id, [e]) preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids for e in event_ids
]) ]))
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
@ -1166,9 +1168,9 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations. on each other for state calculations.
""" """
contexts = yield defer.gatherResults( contexts = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self._prep_event( preserve_fn(self._prep_event)(
origin, origin,
ev_info["event"], ev_info["event"],
state=ev_info.get("state"), state=ev_info.get("state"),
@ -1176,7 +1178,7 @@ class FederationHandler(BaseHandler):
) )
for ev_info in event_infos for ev_info in event_infos
] ]
) ))
yield self.store.persist_events( yield self.store.persist_events(
[ [
@ -1460,9 +1462,9 @@ class FederationHandler(BaseHandler):
# Do auth conflict res. # Do auth conflict res.
logger.info("Different auth: %s", different_auth) logger.info("Different auth: %s", different_auth)
different_events = yield defer.gatherResults( different_events = yield preserve_context_over_deferred(defer.gatherResults(
[ [
self.store.get_event( preserve_fn(self.store.get_event)(
d, d,
allow_none=True, allow_none=True,
allow_rejected=False, allow_rejected=False,
@ -1471,7 +1473,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d] if d in have_events and not have_events[d]
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
if different_events: if different_events:
local_view = dict(auth_events) local_view = dict(auth_events)

View file

@ -28,7 +28,8 @@ from synapse.types import (
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -502,15 +503,17 @@ class MessageHandler(BaseHandler):
lambda states: states[event.event_id] lambda states: states[event.event_id]
) )
(messages, token), current_state = yield defer.gatherResults( (messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[ [
self.store.get_recent_events_for_room( preserve_fn(self.store.get_recent_events_for_room)(
event.room_id, event.room_id,
limit=limit, limit=limit,
end_token=room_end_token, end_token=room_end_token,
), ),
deferred_room_state, deferred_room_state,
] ]
)
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield filter_events_for_client( messages = yield filter_events_for_client(
@ -719,9 +722,9 @@ class MessageHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults( presence, receipts, (messages, token) = yield defer.gatherResults(
[ [
get_presence(), preserve_fn(get_presence)(),
get_receipts(), preserve_fn(get_receipts)(),
self.store.get_recent_events_for_room( preserve_fn(self.store.get_recent_events_for_room)(
room_id, room_id,
limit=limit, limit=limit,
end_token=now_token.room_key, end_token=now_token.room_key,
@ -755,6 +758,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@measure_func("_create_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None): def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids: if prev_event_ids:
@ -806,6 +810,7 @@ class MessageHandler(BaseHandler):
(event, context,) (event, context,)
) )
@measure_func("handle_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event( def handle_new_client_event(
self, self,
@ -934,7 +939,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _notify(): def _notify():
yield run_on_reactor() yield run_on_reactor()
self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
@ -944,6 +949,6 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event( preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations, event, destinations=destinations,
) )

View file

@ -16,7 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import UserID from synapse.types import UserID
@ -169,13 +171,13 @@ class TypingHandler(object):
deferreds = [] deferreds = []
for domain in domains: for domain in domains:
if domain == self.server_name: if domain == self.server_name:
self._push_update_local( preserve_fn(self._push_update_local)(
room_id=room_id, room_id=room_id,
user_id=user_id, user_id=user_id,
typing=typing typing=typing
) )
else: else:
deferreds.append(self.federation.send_edu( deferreds.append(preserve_fn(self.federation.send_edu)(
destination=domain, destination=domain,
edu_type="m.typing", edu_type="m.typing",
content={ content={
@ -185,7 +187,9 @@ class TypingHandler(object):
}, },
)) ))
yield defer.DeferredList(deferreds, consumeErrors=True) yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _recv_edu(self, origin, content): def _recv_edu(self, origin, content):

View file

@ -19,7 +19,7 @@ from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -174,6 +174,7 @@ class Notifier(object):
lambda: len(self.user_to_user_stream), lambda: len(self.user_to_user_stream),
) )
@preserve_fn
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
@ -195,6 +196,7 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
@preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
event to be persisted. event to be persisted.
@ -212,6 +214,7 @@ class Notifier(object):
else: else:
self._on_new_room_event(event, room_stream_id, extra_users) self._on_new_room_event(event, room_stream_id, extra_users)
@preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
@ -226,6 +229,7 @@ class Notifier(object):
rooms=[event.room_id], rooms=[event.room_id],
) )
@preserve_fn
def on_new_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
@ -252,6 +256,7 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
@preserve_fn
def on_new_replication_data(self): def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend """Used to inform replication listeners that something has happend
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""

View file

@ -17,14 +17,15 @@ from twisted.internet import defer
from synapse.util.presentable_names import ( from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event calculate_room_name, name_from_member_event
) )
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_badge_count(store, user_id): def get_badge_count(store, user_id):
invites, joins = yield defer.gatherResults([ invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
store.get_invited_rooms_for_user(user_id), preserve_fn(store.get_invited_rooms_for_user)(user_id),
store.get_rooms_for_user(user_id), preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True) ], consumeErrors=True))
my_receipts_by_room = yield store.get_receipts_for_user( my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read", user_id, "m.read",

View file

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
import pusher import pusher
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
import logging import logging
@ -130,10 +130,12 @@ class PusherPool:
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
deferreds.append( deferreds.append(
p.on_new_notifications(min_stream_id, max_stream_id) preserve_fn(p.on_new_notifications)(
min_stream_id, max_stream_id
)
) )
yield defer.gatherResults(deferreds) yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except: except:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@ -155,10 +157,10 @@ class PusherPool:
if u in self.pushers: if u in self.pushers:
for p in self.pushers[u].values(): for p in self.pushers[u].values():
deferreds.append( deferreds.append(
p.on_new_receipts(min_stream_id, max_stream_id) preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
) )
yield defer.gatherResults(deferreds) yield preserve_context_over_deferred(defer.gatherResults(deferreds))
except: except:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")

View file

@ -403,10 +403,9 @@ class RegisterRestServlet(RestServlet):
# register the user's device # register the user's device
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id = self.device_handler.check_device_registered( return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
return device_id
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self): def _do_guest_registration(self):

View file

@ -25,8 +25,25 @@ from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request):
yield self.auth.get_user_by_req(request)
protocols = yield self.appservice_handler.get_3pe_protocols()
defer.returnValue((200, protocols))
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server)

View file

@ -20,7 +20,9 @@ from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext from synapse.util.logcontext import (
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
)
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -202,7 +204,7 @@ class EventsStore(SQLBaseStore):
deferreds = [] deferreds = []
for room_id, evs_ctxs in partitioned.items(): for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue( d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
current_state=None, current_state=None,
@ -212,7 +214,9 @@ class EventsStore(SQLBaseStore):
for room_id in partitioned.keys(): for room_id in partitioned.keys():
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
return defer.gatherResults(deferreds, consumeErrors=True) return preserve_context_over_deferred(
defer.gatherResults(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -225,7 +229,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
yield deferred yield preserve_context_over_deferred(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token() max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@ -1088,7 +1092,7 @@ class EventsStore(SQLBaseStore):
if not allow_rejected: if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]] rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults( res = yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self._get_event_from_row)( preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"], row["internal_metadata"], row["json"], row["redacts"],
@ -1097,7 +1101,7 @@ class EventsStore(SQLBaseStore):
for row in rows for row in rows
], ],
consumeErrors=True consumeErrors=True
) ))
defer.returnValue({ defer.returnValue({
e.event.event_id: e e.event.event_id: e

View file

@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging import logging
@ -234,12 +234,12 @@ class StreamStore(SQLBaseStore):
results = {} results = {}
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([ res = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.get_room_events_stream_for_room)( preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit, order=order, room_id, from_key, to_key, limit, order=order,
) )
for room_id in rm_ids for room_id in rm_ids
]) ]))
results.update(dict(zip(rm_ids, res))) results.update(dict(zip(rm_ids, res)))
defer.returnValue(results) defer.returnValue(results)

View file

@ -146,10 +146,10 @@ def concurrently_execute(func, args, limit):
except StopIteration: except StopIteration:
pass pass
return defer.gatherResults([ return preserve_context_over_deferred(defer.gatherResults([
preserve_fn(_concurrently_execute_inner)() preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit) for _ in xrange(limit)
], consumeErrors=True).addErrback(unwrapFirstError) ], consumeErrors=True)).addErrback(unwrapFirstError)
class Linearizer(object): class Linearizer(object):
@ -181,7 +181,8 @@ class Linearizer(object):
self.key_to_defer[key] = new_defer self.key_to_defer[key] = new_defer
if current_defer: if current_defer:
yield preserve_context_over_deferred(current_defer) with PreserveLoggingContext():
yield current_defer
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():
@ -264,7 +265,7 @@ class ReadWriteLock(object):
curr_readers.clear() curr_readers.clear()
self.key_to_current_writer[key] = new_defer self.key_to_current_writer[key] = new_defer
yield defer.gatherResults(to_wait_on) yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():

View file

@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
return res return res
def preserve_context_over_deferred(deferred): def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will """Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context. be invoked with the current context.
""" """
current_context = LoggingContext.current_context() if context is None:
d = _PreservingContextDeferred(current_context) context = LoggingContext.current_context()
d = _PreservingContextDeferred(context)
deferred.chainDeferred(d) deferred.chainDeferred(d)
return d return d
@ -316,7 +317,13 @@ def preserve_fn(f):
def g(*args, **kwargs): def g(*args, **kwargs):
with PreserveLoggingContext(current): with PreserveLoggingContext(current):
return f(*args, **kwargs) res = f(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(
res, context=LoggingContext.sentinel
)
else:
return res
return g return g

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging import logging
@ -55,12 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
given events given events
events ([synapse.events.EventBase]): list of events to filter events ([synapse.events.EventBase]): list of events to filter
""" """
forgotten = yield defer.gatherResults([ forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)( preserve_fn(store.who_forgot_in_room)(
room_id, room_id,
) )
for room_id in frozenset(e.room_id for e in events) for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True) ], consumeErrors=True))
# Set of membership event_ids that have been forgotten # Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset( event_id_forgotten = frozenset(