Merge branch 'develop' into notifier_unify

Conflicts:
    synapse/notifier.py

commit 4429e4bf24
28 changed files with 628 additions and 303 deletions
@@ -31,6 +31,7 @@ for port in 8080 8081 8082; do
     #rm $DIR/etc/$port.config
     python -m synapse.app.homeserver \
         --generate-config \
+        --enable_registration \
         -H "localhost:$https_port" \
         --config-path "$DIR/etc/$port.config" \
@@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
 from twisted.internet.protocol import Factory
 from twisted.internet import defer, reactor
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+    preserve_context_over_fn, preserve_context_over_deferred
+)
 import simplejson as json
 import logging

@@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):

     for i in range(5):
         try:
-            with PreserveLoggingContext():
-                protocol = yield endpoint.connect(factory)
-                server_response, server_certificate = yield protocol.remote_key
-                defer.returnValue((server_response, server_certificate))
-                return
+            protocol = yield preserve_context_over_fn(
+                endpoint.connect, factory
+            )
+            server_response, server_certificate = yield preserve_context_over_deferred(
+                protocol.remote_key
+            )
+            defer.returnValue((server_response, server_certificate))
+            return
         except SynapseKeyClientError as e:
             logger.exception("Error getting key for %r" % (server_name,))
             if e.status.startswith("4"):
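Note on the pattern above: preserve_context_over_fn calls a Deferred-returning function while keeping the caller's logging context attached across the asynchronous gap, and preserve_context_over_deferred does the same for an existing Deferred. A rough synchronous sketch of the idea (not synapse's implementation, which hooks into Twisted's callback chain; CURRENT_CONTEXT is a hypothetical stand-in for the logging context):

    # Sketch only: a synchronous stand-in for the logcontext helpers.
    CURRENT_CONTEXT = {"request": None}  # hypothetical stand-in

    def preserve_context_over_fn(fn, *args, **kwargs):
        # Capture the caller's context, clear it while the operation runs
        # (so the operation's own logging is not mis-attributed), and
        # restore it before handing the result back to the caller.
        saved = dict(CURRENT_CONTEXT)
        CURRENT_CONTEXT["request"] = None
        try:
            return fn(*args, **kwargs)
        finally:
            CURRENT_CONTEXT.clear()
            CURRENT_CONTEXT.update(saved)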
@@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes

 from synapse.util.retryutils import get_retry_limiter

-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred

 from OpenSSL import crypto

@@ -111,6 +111,10 @@ class Keyring(object):

         if download is None:
             download = self._get_server_verify_key_impl(server_name, key_ids)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
             self.key_downloads[server_name] = download

             @download.addBoth
@@ -118,7 +122,7 @@ class Keyring(object):
                 del self.key_downloads[server_name]
                 return ret

-        r = yield create_observer(download)
+        r = yield download.observe()
         defer.returnValue(r)

     @defer.inlineCallbacks
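The create_observer to ObservableDeferred change lets many callers wait on a single in-flight key download: each observe() hands back a fresh Deferred, and consumeErrors=True means a failure is delivered to every observer instead of being reported as unhandled on the shared Deferred. A minimal sketch of the class, assuming only Twisted's defer module (the real one lives in synapse.util.async):

    from twisted.internet import defer

    class ObservableDeferred(object):
        # Wrap one underlying Deferred so any number of callers can wait
        # on it via observe() without sharing a callback chain.
        def __init__(self, deferred, consumeErrors=False):
            self._result = None          # (success, value) once resolved
            self._observers = []

            def callback(r):
                self._result = (True, r)
                for d in self._observers:
                    d.callback(r)
                return r

            def errback(f):
                self._result = (False, f)
                for d in self._observers:
                    d.errback(f)
                # Swallow the failure so the underlying deferred does not
                # also report it as unhandled.
                return None if consumeErrors else f

            deferred.addCallbacks(callback, errback)

        def observe(self):
            if self._result is None:
                d = defer.Deferred()
                self._observers.append(d)
                return d
            success, value = self._result
            return defer.succeed(value) if success else defer.fail(value)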
@@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash

 from synapse.api.errors import SynapseError

+from synapse.util import unwrapFirstError
+
 import logging

@@ -94,7 +96,7 @@ class FederationBase(object):
         yield defer.gatherResults(
             [do(pdu) for pdu in pdus],
             consumeErrors=True
-        )
+        ).addErrback(unwrapFirstError)

         defer.returnValue(signed_pdus)
@@ -20,7 +20,6 @@ from .federation_base import FederationBase
 from .units import Transaction, Edu

 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.events import FrozenEvent
 import synapse.metrics

@@ -123,29 +122,28 @@ class FederationServer(FederationBase):

         logger.debug("[%s] Transaction is new", transaction.transaction_id)

-        with PreserveLoggingContext():
-            results = []
+        results = []

-            for pdu in pdu_list:
-                d = self._handle_new_pdu(transaction.origin, pdu)
+        for pdu in pdu_list:
+            d = self._handle_new_pdu(transaction.origin, pdu)

-                try:
-                    yield d
-                    results.append({})
-                except FederationError as e:
-                    self.send_failure(e, transaction.origin)
-                    results.append({"error": str(e)})
-                except Exception as e:
-                    results.append({"error": str(e)})
-                    logger.exception("Failed to handle PDU")
+            try:
+                yield d
+                results.append({})
+            except FederationError as e:
+                self.send_failure(e, transaction.origin)
+                results.append({"error": str(e)})
+            except Exception as e:
+                results.append({"error": str(e)})
+                logger.exception("Failed to handle PDU")

-            if hasattr(transaction, "edus"):
-                for edu in [Edu(**x) for x in transaction.edus]:
-                    self.received_edu(
-                        transaction.origin,
-                        edu.edu_type,
-                        edu.content
-                    )
+        if hasattr(transaction, "edus"):
+            for edu in [Edu(**x) for x in transaction.edus]:
+                self.received_edu(
+                    transaction.origin,
+                    edu.edu_type,
+                    edu.content
+                )

-            for failure in getattr(transaction, "pdu_failures", []):
-                logger.info("Got failure %r", failure)
+        for failure in getattr(transaction, "pdu_failures", []):
+            logger.info("Got failure %r", failure)
@@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.api.constants import Membership, EventTypes
 from synapse.types import UserID

+from synapse.util.logcontext import PreserveLoggingContext
+
 import logging

@@ -137,10 +139,11 @@ class BaseHandler(object):
                 "Failed to get destination from event %s", s.event_id
             )

-        # Don't block waiting on waking up all the listeners.
-        notify_d = self.notifier.on_new_room_event(
-            event, extra_users=extra_users
-        )
+        with PreserveLoggingContext():
+            # Don't block waiting on waking up all the listeners.
+            notify_d = self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )

         def log_failure(f):
             logger.warn(
@@ -15,7 +15,6 @@

 from twisted.internet import defer

-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.types import UserID
 from synapse.events.utils import serialize_event

@@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
             # thundering herds on restart.
             timeout = random.randint(int(timeout*0.9), int(timeout*1.1))

-        with PreserveLoggingContext():
-            events, tokens = yield self.notifier.get_events_for(
-                auth_user, room_ids, pagin_config, timeout
-            )
+        events, tokens = yield self.notifier.get_events_for(
+            auth_user, room_ids, pagin_config, timeout
+        )

         time_now = self.clock.time_msec()
@@ -18,9 +18,11 @@
 from ._base import BaseHandler

 from synapse.api.errors import (
-    AuthError, FederationError, StoreError,
+    AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
+from synapse.util import unwrapFirstError
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
@@ -29,6 +31,8 @@ from synapse.crypto.event_signing import (
 )
 from synapse.types import UserID

+from synapse.util.retryutils import NotRetryingDestination
+
 from twisted.internet import defer

 import itertools
@@ -197,9 +201,10 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)

-        d = self.notifier.on_new_room_event(
-            event, extra_users=extra_users
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )

         def log_failure(f):
             logger.warn(
@@ -218,10 +223,11 @@ class FederationHandler(BaseHandler):

     @log_function
     @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit):
+    def backfill(self, dest, room_id, limit, extremities=[]):
         """ Trigger a backfill request to `dest` for the given `room_id`
         """
-        extremities = yield self.store.get_oldest_events_in_room(room_id)
+        if not extremities:
+            extremities = yield self.store.get_oldest_events_in_room(room_id)

         pdus = yield self.replication_layer.backfill(
             dest,
@@ -248,6 +254,138 @@ class FederationHandler(BaseHandler):

         defer.returnValue(events)

+    @defer.inlineCallbacks
+    def maybe_backfill(self, room_id, current_depth):
+        """Checks the database to see if we should backfill before paginating,
+        and if so do.
+        """
+        extremities = yield self.store.get_oldest_events_with_depth_in_room(
+            room_id
+        )
+
+        if not extremities:
+            logger.debug("Not backfilling as no extremeties found.")
+            return
+
+        # Check if we reached a point where we should start backfilling.
+        sorted_extremeties_tuple = sorted(
+            extremities.items(),
+            key=lambda e: -int(e[1])
+        )
+        max_depth = sorted_extremeties_tuple[0][1]
+
+        if current_depth > max_depth:
+            logger.debug(
+                "Not backfilling as we don't need to. %d < %d",
+                max_depth, current_depth,
+            )
+            return
+
+        # Now we need to decide which hosts to hit first.
+
+        # First we try hosts that are already in the room
+        # TODO: HEURISTIC ALERT.
+
+        curr_state = yield self.state_handler.get_current_state(room_id)
+
+        def get_domains_from_state(state):
+            joined_users = [
+                (state_key, int(event.depth))
+                for (e_type, state_key), event in state.items()
+                if e_type == EventTypes.Member
+                and event.membership == Membership.JOIN
+            ]
+
+            joined_domains = {}
+            for u, d in joined_users:
+                try:
+                    dom = UserID.from_string(u).domain
+                    old_d = joined_domains.get(dom)
+                    if old_d:
+                        joined_domains[dom] = min(d, old_d)
+                    else:
+                        joined_domains[dom] = d
+                except:
+                    pass
+
+            return sorted(joined_domains.items(), key=lambda d: d[1])
+
+        curr_domains = get_domains_from_state(curr_state)
+
+        likely_domains = [
+            domain for domain, depth in curr_domains
+        ]
+
+        @defer.inlineCallbacks
+        def try_backfill(domains):
+            # TODO: Should we try multiple of these at a time?
+            for dom in domains:
+                try:
+                    events = yield self.backfill(
+                        dom, room_id,
+                        limit=100,
+                        extremities=[e for e in extremities.keys()]
+                    )
+                except SynapseError as e:
+                    logger.info(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+                except CodeMessageException as e:
+                    if 400 <= e.code < 500:
+                        raise
+
+                    logger.info(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+                except NotRetryingDestination as e:
+                    logger.info(e.message)
+                    continue
+                except Exception as e:
+                    logger.warn(
+                        "Failed to backfill from %s because %s",
+                        dom, e,
+                    )
+                    continue
+
+                if events:
+                    defer.returnValue(True)
+            defer.returnValue(False)
+
+        success = yield try_backfill(likely_domains)
+        if success:
+            defer.returnValue(True)
+
+        # Huh, well *those* domains didn't work out. Lets try some domains
+        # from the time.
+
+        tried_domains = set(likely_domains)
+
+        event_ids = list(extremities.keys())
+
+        states = yield defer.gatherResults([
+            self.state_handler.resolve_state_groups([e])
+            for e in event_ids
+        ])
+        states = dict(zip(event_ids, [s[1] for s in states]))
+
+        for e_id, _ in sorted_extremeties_tuple:
+            likely_domains = get_domains_from_state(states[e_id])
+
+            success = yield try_backfill([
+                dom for dom in likely_domains
+                if dom not in tried_domains
+            ])
+            if success:
+                defer.returnValue(True)
+
+            tried_domains.update(likely_domains)
+
+        defer.returnValue(False)
+
     @defer.inlineCallbacks
     def send_invite(self, target_host, event):
         """ Sends the invite to the remote server for signing.
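The heuristic in maybe_backfill is: take the deepest backward extremity as the ceiling, and try hosts in order of how early they joined the room, since older joins are more likely to hold the history. A runnable toy walk-through of that ordering, on made-up data:

    # Sketch of the host-selection heuristic, on toy data.
    extremities = {"$event_a": 10, "$event_b": 3}

    # Deepest backward extremity first; backfill is only worth attempting
    # when the requested depth is at or below this maximum.
    sorted_extremities = sorted(extremities.items(), key=lambda e: -int(e[1]))
    max_depth = sorted_extremities[0][1]
    current_depth = 5
    assert current_depth <= max_depth  # so we would backfill here

    # Joined members, as (user_id, depth-of-their-membership-event) pairs.
    joined_users = [("@alice:hs1.example", 2), ("@bob:hs1.example", 7),
                    ("@carol:hs2.example", 4)]

    joined_domains = {}
    for user_id, depth in joined_users:
        domain = user_id.split(":", 1)[1]
        old = joined_domains.get(domain)
        joined_domains[domain] = min(depth, old) if old is not None else depth

    # Hosts whose earliest join is oldest (smallest depth) are tried first.
    likely_domains = [d for d, _ in sorted(joined_domains.items(),
                                           key=lambda item: item[1])]
    print(likely_domains)  # ['hs1.example', 'hs2.example']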
@@ -431,9 +569,10 @@ class FederationHandler(BaseHandler):
             auth_events=auth_events,
         )

-        d = self.notifier.on_new_room_event(
-            new_event, extra_users=[joinee]
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                new_event, extra_users=[joinee]
+            )

         def log_failure(f):
             logger.warn(
@@ -512,9 +651,10 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)

-        d = self.notifier.on_new_room_event(
-            event, extra_users=extra_users
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                event, extra_users=extra_users
+            )

         def log_failure(f):
             logger.warn(
@@ -594,9 +734,10 @@ class FederationHandler(BaseHandler):
         )

         target_user = UserID.from_string(event.state_key)
-        d = self.notifier.on_new_room_event(
-            event, extra_users=[target_user],
-        )
+        with PreserveLoggingContext():
+            d = self.notifier.on_new_room_event(
+                event, extra_users=[target_user],
+            )

         def log_failure(f):
             logger.warn(
@@ -921,7 +1062,7 @@ class FederationHandler(BaseHandler):
                 if d in have_events and not have_events[d]
             ],
             consumeErrors=True
-        )
+        ).addErrback(unwrapFirstError)

         if different_events:
             local_view = dict(auth_events)
@@ -20,8 +20,9 @@ from synapse.api.errors import RoomError, SynapseError
 from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
+from synapse.util import unwrapFirstError
 from synapse.util.logcontext import PreserveLoggingContext
-from synapse.types import UserID
+from synapse.types import UserID, RoomStreamToken

 from ._base import BaseHandler

@@ -89,9 +90,19 @@ class MessageHandler(BaseHandler):

         if not pagin_config.from_token:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token()
+                yield self.hs.get_event_sources().get_current_token(
+                    direction='b'
+                )
             )

+        room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
+        if room_token.topological is None:
+            raise SynapseError(400, "Invalid token")
+
+        yield self.hs.get_handlers().federation_handler.maybe_backfill(
+            room_id, room_token.topological
+        )
+
         user = UserID.from_string(user_id)

         events, next_key = yield data_source.get_pagination_rows(
@@ -303,7 +314,7 @@ class MessageHandler(BaseHandler):
                     event.room_id
                 ),
             ]
-        )
+        ).addErrback(unwrapFirstError)

         start_token = now_token.copy_and_replace("room_key", token[0])
         end_token = now_token.copy_and_replace("room_key", token[1])
@@ -328,7 +339,7 @@ class MessageHandler(BaseHandler):
         yield defer.gatherResults(
             [handle_room(e) for e in room_list],
             consumeErrors=True
-        )
+        ).addErrback(unwrapFirstError)

         ret = {
             "rooms": rooms_ret,
@@ -18,14 +18,15 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, AuthError
 from synapse.api.constants import PresenceState

-from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logutils import log_function
 from synapse.types import UserID
 import synapse.metrics

 from ._base import BaseHandler

 import logging
+from collections import OrderedDict


 logger = logging.getLogger(__name__)
@@ -143,7 +144,7 @@ class PresenceHandler(BaseHandler):
         self._remote_offline_serials = []

         # map any user to a UserPresenceCache
-        self._user_cachemap = {}
+        self._user_cachemap = OrderedDict()  # keep them sorted by serial
         self._user_cachemap_latest_serial = 0

         metrics.register_callback(
@@ -165,6 +166,14 @@ class PresenceHandler(BaseHandler):
         else:
             return UserPresenceCache()

+    def _bump_serial(self, user=None):
+        self._user_cachemap_latest_serial += 1
+
+        if user:
+            # Move to end
+            cache = self._user_cachemap.pop(user)
+            self._user_cachemap[user] = cache
+
     def registered_user(self, user):
         return self.store.create_presence(user.localpart)

@@ -278,15 +287,14 @@ class PresenceHandler(BaseHandler):
         now_online = state["presence"] != PresenceState.OFFLINE
         was_polling = target_user in self._user_cachemap

-        with PreserveLoggingContext():
-            if now_online and not was_polling:
-                self.start_polling_presence(target_user, state=state)
-            elif not now_online and was_polling:
-                self.stop_polling_presence(target_user)
+        if now_online and not was_polling:
+            self.start_polling_presence(target_user, state=state)
+        elif not now_online and was_polling:
+            self.stop_polling_presence(target_user)

         # TODO(paul): perform a presence push as part of start/stop poll so
         # we don't have to do this all the time
         self.changed_presencelike_data(target_user, state)

     def bump_presence_active_time(self, user, now=None):
         if now is None:
@@ -301,7 +309,7 @@ class PresenceHandler(BaseHandler):
     def changed_presencelike_data(self, user, state):
         statuscache = self._get_or_make_usercache(user)

-        self._user_cachemap_latest_serial += 1
+        self._bump_serial(user=user)
         statuscache.update(state, serial=self._user_cachemap_latest_serial)

         return self.push_presence(user, statuscache=statuscache)
@@ -323,7 +331,7 @@ class PresenceHandler(BaseHandler):

         # No actual update but we need to bump the serial anyway for the
         # event source
-        self._user_cachemap_latest_serial += 1
+        self._bump_serial()
         statuscache.update({}, serial=self._user_cachemap_latest_serial)

         self.push_update_to_local_and_remote(
@@ -408,10 +416,10 @@ class PresenceHandler(BaseHandler):
         yield self.store.set_presence_list_accepted(
             observer_user.localpart, observed_user.to_string()
         )
-        with PreserveLoggingContext():
-            self.start_polling_presence(
-                observer_user, target_user=observed_user
-            )
+
+        self.start_polling_presence(
+            observer_user, target_user=observed_user
+        )

     @defer.inlineCallbacks
     def deny_presence(self, observed_user, observer_user):
@@ -430,10 +438,9 @@ class PresenceHandler(BaseHandler):
             observer_user.localpart, observed_user.to_string()
         )

-        with PreserveLoggingContext():
-            self.stop_polling_presence(
-                observer_user, target_user=observed_user
-            )
+        self.stop_polling_presence(
+            observer_user, target_user=observed_user
+        )

     @defer.inlineCallbacks
     def get_presence_list(self, observer_user, accepted=None):
@@ -706,7 +713,7 @@ class PresenceHandler(BaseHandler):

         statuscache = self._get_or_make_usercache(user)

-        self._user_cachemap_latest_serial += 1
+        self._bump_serial(user=user)
         statuscache.update(state, serial=self._user_cachemap_latest_serial)

         if not observers and not room_ids:
@@ -766,8 +773,7 @@ class PresenceHandler(BaseHandler):
             if not self._remote_sendmap[user]:
                 del self._remote_sendmap[user]

-        with PreserveLoggingContext():
-            yield defer.DeferredList(deferreds, consumeErrors=True)
+        yield defer.DeferredList(deferreds, consumeErrors=True)

     @defer.inlineCallbacks
     def push_update_to_local_and_remote(self, observed_user, statuscache,
@@ -812,10 +818,11 @@ class PresenceHandler(BaseHandler):

     def push_update_to_clients(self, observed_user, users_to_push=[],
                                room_ids=[], statuscache=None):
-        self.notifier.on_new_user_event(
-            users_to_push,
-            room_ids,
-        )
+        with PreserveLoggingContext():
+            self.notifier.on_new_user_event(
+                users_to_push,
+                room_ids,
+            )


 class PresenceEventSource(object):
@@ -866,10 +873,15 @@ class PresenceEventSource(object):

         updates = []
         # TODO(paul): use a DeferredList ? How to limit concurrency.
-        for observed_user in cachemap.keys():
+        for observed_user in reversed(cachemap.keys()):
             cached = cachemap[observed_user]

-            if cached.serial <= from_key or cached.serial > max_serial:
+            # Since this is ordered in descending order of serial, we can just
+            # stop once we've seen enough
+            if cached.serial <= from_key:
+                break
+
+            if cached.serial > max_serial:
                 continue

             if not (yield self.is_visible(observer_user, observed_user)):
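The OrderedDict change earlier in this file is what makes the reversed()/break above work: _bump_serial pops a user and re-inserts it, so the cache stays ordered by serial and a reverse scan can stop at the first entry at or below from_key. A toy demonstration:

    from collections import OrderedDict

    # Toy model of the serial-ordered presence cache.
    cache = OrderedDict()
    serial = 0

    def bump(user):
        # Mirror of _bump_serial: advance the serial and move the user
        # to the end of the ordering.
        global serial
        serial += 1
        if user in cache:
            cache.pop(user)
        cache[user] = serial

    for u in ("@a:hs", "@b:hs", "@c:hs"):
        bump(u)
    bump("@a:hs")            # @a is now the most recently updated

    from_key = 2
    updates = []
    for u in reversed(cache):
        if cache[u] <= from_key:
            break            # everything earlier has already been seen
        updates.append(u)
    print(updates)           # ['@a:hs', '@c:hs']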
@@ -17,8 +17,8 @@ from twisted.internet import defer

 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.api.constants import EventTypes, Membership
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import UserID
+from synapse.util import unwrapFirstError

 from ._base import BaseHandler

@@ -154,14 +154,13 @@ class ProfileHandler(BaseHandler):
         if not self.hs.is_mine(user):
             defer.returnValue(None)

-        with PreserveLoggingContext():
-            (displayname, avatar_url) = yield defer.gatherResults(
-                [
-                    self.store.get_profile_displayname(user.localpart),
-                    self.store.get_profile_avatar_url(user.localpart),
-                ],
-                consumeErrors=True
-            )
+        (displayname, avatar_url) = yield defer.gatherResults(
+            [
+                self.store.get_profile_displayname(user.localpart),
+                self.store.get_profile_avatar_url(user.localpart),
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)

         state["displayname"] = displayname
         state["avatar_url"] = avatar_url
@@ -21,7 +21,7 @@ from ._base import BaseHandler
 from synapse.types import UserID, RoomAlias, RoomID
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import StoreError, SynapseError
-from synapse.util import stringutils
+from synapse.util import stringutils, unwrapFirstError
 from synapse.util.async import run_on_reactor
 from synapse.events.utils import serialize_event

@@ -537,7 +537,7 @@ class RoomListHandler(BaseHandler):
                 for room in chunk
             ],
             consumeErrors=True,
-        )
+        ).addErrback(unwrapFirstError)

         for i, room in enumerate(chunk):
             room["num_joined_members"] = len(results[i])
@@ -577,8 +577,8 @@ class RoomEventSource(object):

         defer.returnValue((events, end_key))

-    def get_current_key(self):
-        return self.store.get_room_events_max_id()
+    def get_current_key(self, direction='f'):
+        return self.store.get_room_events_max_id(direction)

     @defer.inlineCallbacks
     def get_pagination_rows(self, user, config, key):
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from ._base import BaseHandler

 from synapse.api.errors import SynapseError, AuthError
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import UserID

 import logging
@@ -216,7 +217,8 @@ class TypingNotificationHandler(BaseHandler):
         self._latest_room_serial += 1
         self._room_serials[room_id] = self._latest_room_serial

-        self.notifier.on_new_user_event(rooms=[room_id])
+        with PreserveLoggingContext():
+            self.notifier.on_new_user_event(rooms=[room_id])


 class TypingNotificationEventSource(object):
@@ -14,6 +14,7 @@
 # limitations under the License.

 from synapse.api.errors import CodeMessageException
+from synapse.util.logcontext import preserve_context_over_fn
 from syutil.jsonutil import encode_canonical_json
 import synapse.metrics

@@ -61,7 +62,10 @@ class SimpleHttpClient(object):
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
         outgoing_requests_counter.inc(method)
-        d = self.agent.request(method, *args, **kwargs)
+        d = preserve_context_over_fn(
+            self.agent.request,
+            method, *args, **kwargs
+        )

         def _cb(response):
             incoming_responses_counter.inc(method, response.code)
@@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone

 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util.async import sleep
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
 import synapse.metrics

 from syutil.jsonutil import encode_canonical_json
@@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object):
             producer = body_callback(method, url_bytes, headers_dict)

             try:
-                with PreserveLoggingContext():
-                    request_deferred = self.agent.request(
-                        destination,
-                        endpoint,
-                        method,
-                        path_bytes,
-                        param_bytes,
-                        query_bytes,
-                        Headers(headers_dict),
-                        producer
-                    )
-
-                    response = yield self.clock.time_bound_deferred(
-                        request_deferred,
-                        time_out=60,
-                    )
+                request_deferred = preserve_context_over_fn(
+                    self.agent.request,
+                    destination,
+                    endpoint,
+                    method,
+                    path_bytes,
+                    param_bytes,
+                    query_bytes,
+                    Headers(headers_dict),
+                    producer
+                )
+
+                response = yield self.clock.time_bound_deferred(
+                    request_deferred,
+                    time_out=60,
+                )

                 logger.debug("Got response to %s", method)
                 break
@@ -17,7 +17,7 @@
 from synapse.api.errors import (
     cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
 )
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 import synapse.metrics

 from syutil.jsonutil import (
@@ -85,7 +85,9 @@ def request_handler(request_handler):
                 "Received request: %s %s",
                 request.method, request.path
             )
-            yield request_handler(self, request)
+            d = request_handler(self, request)
+            with PreserveLoggingContext():
+                yield d
             code = request.code
         except CodeMessageException as e:
             code = e.code
@@ -16,7 +16,6 @@
 from twisted.internet import defer

 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.async import run_on_reactor
 from synapse.types import StreamToken
 import synapse.metrics
@@ -193,12 +192,11 @@ class Notifier(object):

         logger.debug("on_new_room_event listeners %s", listeners)

-        with PreserveLoggingContext():
-            for listener in listeners:
-                try:
-                    listener.notify(self)
-                except:
-                    logger.exception("Failed to notify listener")
+        for listener in listeners:
+            try:
+                listener.notify(self)
+            except:
+                logger.exception("Failed to notify listener")

     @defer.inlineCallbacks
     @log_function
@@ -225,12 +223,11 @@ class Notifier(object):

         listeners |= room_listeners

-        with PreserveLoggingContext():
-            for listener in listeners:
-                try:
-                    listener.notify(self)
-                except:
-                    logger.exception("Failed to notify listener")
+        for listener in listeners:
+            try:
+                listener.notify(self)
+            except:
+                logger.exception("Failed to notify listener")

     @defer.inlineCallbacks
     def wait_for_events(self, user, rooms, timeout, callback,
@@ -25,7 +25,7 @@ from twisted.internet import defer
 from twisted.web.resource import Resource
 from twisted.protocols.basic import FileSender

-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred

 import os

@@ -83,13 +83,17 @@ class BaseMediaResource(Resource):
         download = self.downloads.get(key)
         if download is None:
             download = self._get_remote_media_impl(server_name, media_id)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
             self.downloads[key] = download

             @download.addBoth
             def callback(media_info):
                 del self.downloads[key]
                 return media_info
-        return create_observer(download)
+        return download.observe()

     @defer.inlineCallbacks
     def _get_remote_media_impl(self, server_name, media_id):
@@ -18,7 +18,7 @@ from synapse.api.errors import StoreError
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics

@@ -420,10 +420,11 @@ class SQLBaseStore(object):
             self._txn_perf_counters.update(desc, start, end)
             sql_txn_timer.inc_by(duration, desc)

-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )

         for after_callback, after_args in after_callbacks:
             after_callback(*after_args)
         defer.returnValue(result)
@@ -79,6 +79,28 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )

+    def get_oldest_events_with_depth_in_room(self, room_id):
+        return self.runInteraction(
+            "get_oldest_events_with_depth_in_room",
+            self.get_oldest_events_with_depth_in_room_txn,
+            room_id,
+        )
+
+    def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
+        sql = (
+            "SELECT b.event_id, MAX(e.depth) FROM events as e"
+            " INNER JOIN event_edges as g"
+            " ON g.event_id = e.event_id AND g.room_id = e.room_id"
+            " INNER JOIN event_backward_extremities as b"
+            " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
+            " WHERE b.room_id = ? AND g.is_state is ?"
+            " GROUP BY b.event_id"
+        )
+
+        txn.execute(sql, (room_id, False,))
+
+        return dict(txn.fetchall())
+
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
             txn,
@@ -247,11 +269,13 @@ class EventFederationStore(SQLBaseStore):
         do_insert = depth < min_depth if min_depth else True

         if do_insert:
-            self._simple_insert_txn(
+            self._simple_upsert_txn(
                 txn,
                 table="room_depth",
-                values={
+                keyvalues={
                     "room_id": room_id,
+                },
+                values={
                     "min_depth": depth,
                 },
             )
@@ -306,31 +330,27 @@ class EventFederationStore(SQLBaseStore):

             txn.execute(query, (event_id, room_id))

-            # Insert all the prev_events as a backwards thing, they'll get
-            # deleted in a second if they're incorrect anyway.
-            self._simple_insert_many_txn(
-                txn,
-                table="event_backward_extremities",
-                values=[
-                    {
-                        "event_id": e_id,
-                        "room_id": room_id,
-                    }
-                    for e_id, _ in prev_events
-                ],
-            )
+            query = (
+                "INSERT INTO event_backward_extremities (event_id, room_id)"
+                " SELECT ?, ? WHERE NOT EXISTS ("
+                " SELECT 1 FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
+                " )"
+                " AND NOT EXISTS ("
+                " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
+                " )"
+            )

-            # Also delete from the backwards extremities table all ones that
-            # reference events that we have already seen
+            txn.executemany(query, [
+                (e_id, room_id, e_id, room_id, e_id, room_id, )
+                for e_id, _ in prev_events
+            ])

             query = (
-                "DELETE FROM event_backward_extremities WHERE EXISTS ("
-                "SELECT 1 FROM events "
-                "WHERE "
-                "event_backward_extremities.event_id = events.event_id "
-                "AND not events.outlier "
-                ")"
+                "DELETE FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
             )
-            txn.execute(query)
+            txn.execute(query, (event_id, room_id))

             txn.call_after(
                 self.get_latest_event_ids_in_room.invalidate, room_id
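The rewritten query inserts each prev_event as a backward extremity only when it is not already recorded and the event itself has not been persisted, replacing the old insert-then-clean-up dance. A self-contained sqlite3 sketch of that conditional-insert shape (toy two-column schema, not synapse's real one):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE events (event_id TEXT, room_id TEXT);
        CREATE TABLE event_backward_extremities (event_id TEXT, room_id TEXT);
    """)

    query = (
        "INSERT INTO event_backward_extremities (event_id, room_id)"
        " SELECT ?, ? WHERE NOT EXISTS ("
        "   SELECT 1 FROM event_backward_extremities"
        "   WHERE event_id = ? AND room_id = ?"
        " ) AND NOT EXISTS ("
        "   SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
        " )"
    )

    room_id = "!room:hs"
    prev_events = ["$prev1", "$prev1", "$prev2"]  # note the duplicate
    conn.executemany(query, [
        (e_id, room_id, e_id, room_id, e_id, room_id)
        for e_id in prev_events
    ])

    rows = conn.execute(
        "SELECT event_id FROM event_backward_extremities"
    ).fetchall()
    print(rows)  # [('$prev1',), ('$prev2',)] -- the duplicate was a no-op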
@@ -135,19 +135,17 @@ class EventsStore(SQLBaseStore):
         outlier = event.internal_metadata.is_outlier()

         if not outlier:
-            self._store_state_groups_txn(txn, event, context)
-
             self._update_min_depth_for_room_txn(
                 txn,
                 event.room_id,
                 event.depth
             )

-        have_persisted = self._simple_select_one_onecol_txn(
+        have_persisted = self._simple_select_one_txn(
             txn,
-            table="event_json",
+            table="events",
             keyvalues={"event_id": event.event_id},
-            retcol="event_id",
+            retcols=["event_id", "outlier"],
             allow_none=True,
         )

@@ -162,7 +160,9 @@ class EventsStore(SQLBaseStore):
         # if we are persisting an event that we had persisted as an outlier,
         # but is no longer one.
         if have_persisted:
-            if not outlier:
+            if not outlier and have_persisted["outlier"]:
+                self._store_state_groups_txn(txn, event, context)
+
                 sql = (
                     "UPDATE event_json SET internal_metadata = ?"
                     " WHERE event_id = ?"
@@ -182,6 +182,9 @@ class EventsStore(SQLBaseStore):
             )
             return

+        if not outlier:
+            self._store_state_groups_txn(txn, event, context)
+
         self._handle_prev_events(
             txn,
             outlier=outlier,
@@ -37,11 +37,9 @@ from twisted.internet import defer

 from ._base import SQLBaseStore
 from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function

-from collections import namedtuple
-
 import logging

@@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"


-class _StreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
-
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, follewed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
-    """
-    __slots__ = []
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == 't':
-                parts = string[1:].split('-', 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
-
-    def lower_bound(self):
-        if self.topological is None:
-            return "(%d < %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d < %s OR (%d = %s AND %d < %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
-
-    def upper_bound(self):
-        if self.topological is None:
-            return "(%d >= %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d > %s OR (%d = %s AND %d >= %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
+def lower_bound(token):
+    if token.topological is None:
+        return "(%d < %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d < %s OR (%d = %s AND %d < %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
+
+
+def upper_bound(token):
+    if token.topological is None:
+        return "(%d >= %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )


 class StreamStore(SQLBaseStore):
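For reference, what the new module-level helpers emit for each token flavour (a stand-alone copy of lower_bound, exercised directly):

    from collections import namedtuple

    RoomStreamToken = namedtuple("RoomStreamToken", ["topological", "stream"])

    def lower_bound(token):
        if token.topological is None:
            return "(%d < %s)" % (token.stream, "stream_ordering")
        else:
            return "(%d < %s OR (%d = %s AND %d < %s))" % (
                token.topological, "topological_ordering",
                token.topological, "topological_ordering",
                token.stream, "stream_ordering",
            )

    print(lower_bound(RoomStreamToken(None, 17)))
    # (17 < stream_ordering)
    print(lower_bound(RoomStreamToken(5, 17)))
    # (5 < topological_ordering OR (5 = topological_ordering AND 17 < stream_ordering))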
@@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE

         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)

         if from_key == to_key:
             defer.returnValue(([], to_key))
@@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE

         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)

         if from_key == to_key:
             return defer.succeed(([], to_key))
@@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
         args = [False, room_id]
         if direction == 'b':
             order = "DESC"
-            bounds = _StreamToken.parse(from_key).upper_bound()
+            bounds = upper_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).lower_bound()
+                    bounds, lower_bound(RoomStreamToken.parse(to_key))
                 )
         else:
             order = "ASC"
-            bounds = _StreamToken.parse(from_key).lower_bound()
+            bounds = lower_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).upper_bound()
+                    bounds, upper_bound(RoomStreamToken.parse(to_key))
                 )

         if int(limit) > 0:
@@ -333,7 +281,7 @@ class StreamStore(SQLBaseStore):
                 # when we are going backwards so we subtract one from the
                 # stream part.
                 toke -= 1
-            next_token = str(_StreamToken(topo, toke))
+            next_token = str(RoomStreamToken(topo, toke))
         else:
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_key if to_key else from_key
@@ -354,7 +302,7 @@ class StreamStore(SQLBaseStore):
                           with_feedback=False, from_token=None):
         # TODO (erikj): Handle compressed feedback

-        end_token = _StreamToken.parse_stream_token(end_token)
+        end_token = RoomStreamToken.parse_stream_token(end_token)

         if from_token is None:
             sql = (
@@ -365,7 +313,7 @@ class StreamStore(SQLBaseStore):
                 " LIMIT ?"
             )
         else:
-            from_token = _StreamToken.parse_stream_token(from_token)
+            from_token = RoomStreamToken.parse_stream_token(from_token)
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
@@ -395,7 +343,7 @@ class StreamStore(SQLBaseStore):
             # stream part.
             topo = rows[0]["topological_ordering"]
             toke = rows[0]["stream_ordering"] - 1
-            start_token = str(_StreamToken(topo, toke))
+            start_token = str(RoomStreamToken(topo, toke))

             token = (start_token, str(end_token))
         else:
@ -416,9 +364,25 @@ class StreamStore(SQLBaseStore):
         )

     @defer.inlineCallbacks
-    def get_room_events_max_id(self):
+    def get_room_events_max_id(self, direction='f'):
         token = yield self._stream_id_gen.get_max_token(self)
-        defer.returnValue("s%d" % (token,))
+        if direction != 'b':
+            defer.returnValue("s%d" % (token,))
+        else:
+            topo = yield self.runInteraction(
+                "_get_max_topological_txn", self._get_max_topological_txn
+            )
+            defer.returnValue("t%d-%d" % (topo, token))
+
+    def _get_max_topological_txn(self, txn):
+        txn.execute(
+            "SELECT MAX(topological_ordering) FROM events"
+            " WHERE outlier = ?",
+            (False,)
+        )
+
+        rows = txn.fetchall()
+        return rows[0][0] if rows else 0

     @defer.inlineCallbacks
     def _get_min_token(self):
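
get_room_events_max_id() now returns a different token shape per direction:
forwards pagination gets a live stream token, while backwards pagination gets
a topological token anchored at the deepest non-outlier event. A small
illustration with hypothetical values:

    # Stand-ins for self._stream_id_gen.get_max_token(self) and for the
    # result of _get_max_topological_txn (hypothetical values).
    token = 3542
    topo = 512

    assert "s%d" % (token,) == "s3542"              # direction='f'
    assert "t%d-%d" % (topo, token) == "t512-3542"  # direction='b'
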
@ -439,5 +403,5 @@ class StreamStore(SQLBaseStore):
             stream = row["stream_ordering"]
             topo = event.depth
             internal = event.internal_metadata
-            internal.before = str(_StreamToken(topo, stream - 1))
-            internal.after = str(_StreamToken(topo, stream))
+            internal.before = str(RoomStreamToken(topo, stream - 1))
+            internal.after = str(RoomStreamToken(topo, stream))

@ -31,7 +31,7 @@ class NullSource(object):
     def get_new_events_for_user(self, user, from_key, limit):
         return defer.succeed(([], from_key))

-    def get_current_key(self):
+    def get_current_key(self, direction='f'):
         return defer.succeed(0)

     def get_pagination_rows(self, user, pagination_config, key):
@ -52,10 +52,10 @@ class EventSources(object):
     }

     @defer.inlineCallbacks
-    def get_current_token(self):
+    def get_current_token(self, direction='f'):
         token = StreamToken(
             room_key=(
-                yield self.sources["room"].get_current_key()
+                yield self.sources["room"].get_current_key(direction)
             ),
             presence_key=(
                 yield self.sources["presence"].get_current_key()

@ -121,4 +121,56 @@ class StreamToken(
         return StreamToken(**d)

+
+class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
+    """Tokens are positions between events. The token "s1" comes after event 1.
+
+            s0    s1
+            |     |
+        [0] V [1] V [2]
+
+    Tokens can either be a point in the live event stream or a cursor going
+    through historic events.
+
+    When traversing the live event stream events are ordered by when they
+    arrived at the homeserver.
+
+    When traversing historic events the events are ordered by their depth in
+    the event graph "topological_ordering" and then by when they arrived at the
+    homeserver "stream_ordering".
+
+    Live tokens start with an "s" followed by the "stream_ordering" id of the
+    event it comes after. Historic tokens start with a "t" followed by the
+    "topological_ordering" id of the event it comes after, followed by "-",
+    followed by the "stream_ordering" id of the event it comes after.
+    """
+    __slots__ = []
+
+    @classmethod
+    def parse(cls, string):
+        try:
+            if string[0] == 's':
+                return cls(topological=None, stream=int(string[1:]))
+            if string[0] == 't':
+                parts = string[1:].split('-', 1)
+                return cls(topological=int(parts[0]), stream=int(parts[1]))
+        except:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    @classmethod
+    def parse_stream_token(cls, string):
+        try:
+            if string[0] == 's':
+                return cls(topological=None, stream=int(string[1:]))
+        except:
+            pass
+        raise SynapseError(400, "Invalid token %r" % (string,))
+
+    def __str__(self):
+        if self.topological is not None:
+            return "t%d-%d" % (self.topological, self.stream)
+        else:
+            return "s%d" % (self.stream,)
+
+
 ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

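
A quick round trip through the two token forms the docstring describes, using
only the RoomStreamToken class added in this hunk:

    live = RoomStreamToken.parse("s3542")
    assert live.topological is None and live.stream == 3542
    assert str(live) == "s3542"

    historic = RoomStreamToken.parse("t512-3542")
    assert historic == RoomStreamToken(topological=512, stream=3542)
    assert str(historic) == "t512-3542"

    # parse_stream_token() only accepts live "s" tokens; feeding it
    # "t512-3542" raises SynapseError(400, ...).
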
@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext

 from twisted.internet import defer, reactor, task

@ -23,6 +23,12 @@ import logging
 logger = logging.getLogger(__name__)


+def unwrapFirstError(failure):
+    # defer.gatherResults and DeferredLists wrap failures.
+    failure.trap(defer.FirstError)
+    return failure.value.subFailure
+
+
 class Clock(object):
     """A small utility that obtains current time-of-day so that time may be
     mocked during unit-tests.
@ -50,9 +56,12 @@ class Clock(object):
         current_context = LoggingContext.current_context()

         def wrapped_callback():
-            LoggingContext.thread_local.current_context = current_context
-            callback()
-        return reactor.callLater(delay, wrapped_callback)
+            with PreserveLoggingContext():
+                LoggingContext.thread_local.current_context = current_context
+                callback()
+
+        with PreserveLoggingContext():
+            return reactor.callLater(delay, wrapped_callback)

     def cancel_call_later(self, timer):
         timer.cancel()
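
The point of the change above: the scheduled callback should run in the
logging context that scheduled it, and the reactor loop should be returned to
the sentinel context both when scheduling and after the callback finishes. A
minimal sketch, assuming Clock and a named LoggingContext as used elsewhere in
synapse.util:

    clock = Clock()

    with LoggingContext("request-123"):
        # wrapped_callback re-enters "request-123" around the call, then
        # PreserveLoggingContext restores the sentinel for the reactor.
        clock.call_later(5, lambda: logger.info("still in request-123"))
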
@ -16,15 +16,13 @@

 from twisted.internet import defer, reactor

-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred


-@defer.inlineCallbacks
 def sleep(seconds):
     d = defer.Deferred()
     reactor.callLater(seconds, d.callback, seconds)
-    with PreserveLoggingContext():
-        yield d
+    return preserve_context_over_deferred(d)


 def run_on_reactor():
@ -34,20 +32,56 @@ def run_on_reactor():
     return sleep(0)


-def create_observer(deferred):
-    """Creates a deferred that observes the result or failure of the given
-    deferred *without* affecting the given deferred.
-    """
-    d = defer.Deferred()
-
-    def callback(r):
-        d.callback(r)
-        return r
-
-    def errback(f):
-        d.errback(f)
-        return f
-
-    deferred.addCallbacks(callback, errback)
-
-    return d
+class ObservableDeferred(object):
+    """Wraps a deferred object so that we can add observer deferreds. These
+    observer deferreds do not affect the callback chain of the original
+    deferred.
+
+    If consumeErrors is true errors will be captured from the origin deferred.
+    """
+
+    __slots__ = ["_deferred", "_observers", "_result"]
+
+    def __init__(self, deferred, consumeErrors=False):
+        object.__setattr__(self, "_deferred", deferred)
+        object.__setattr__(self, "_result", None)
+        object.__setattr__(self, "_observers", [])
+
+        def callback(r):
+            self._result = (True, r)
+            while self._observers:
+                try:
+                    self._observers.pop().callback(r)
+                except:
+                    pass
+            return r
+
+        def errback(f):
+            self._result = (False, f)
+            while self._observers:
+                try:
+                    self._observers.pop().errback(f)
+                except:
+                    pass
+
+            if consumeErrors:
+                return None
+            else:
+                return f
+
+        deferred.addCallbacks(callback, errback)
+
+    def observe(self):
+        if not self._result:
+            d = defer.Deferred()
+            self._observers.append(d)
+            return d
+        else:
+            success, res = self._result
+            return defer.succeed(res) if success else defer.fail(res)
+
+    def __getattr__(self, name):
+        return getattr(self._deferred, name)
+
+    def __setattr__(self, name, value):
+        setattr(self._deferred, name, value)

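
A usage sketch for ObservableDeferred as defined above (assuming the defer
import already present in this module): observers get independent deferreds
and cannot corrupt the underlying callback chain.

    results = []

    underlying = defer.Deferred()
    observable = ObservableDeferred(underlying, consumeErrors=True)

    observable.observe().addCallback(results.append)
    observable.observe().addCallback(results.append)

    underlying.callback("done")
    assert results == ["done", "done"]

    # Observers attached after completion resolve immediately from the
    # cached (success, result) pair:
    observable.observe().addCallback(results.append)
    assert results == ["done", "done", "done"]
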
@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.util.logcontext import PreserveLoggingContext

 from twisted.internet import defer

+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred,
+)
+
+from synapse.util import unwrapFirstError
+
 import logging

|
@ -93,7 +97,6 @@ class Signal(object):
|
||||||
Each observer callable may return a Deferred."""
|
Each observer callable may return a Deferred."""
|
||||||
self.observers.append(observer)
|
self.observers.append(observer)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def fire(self, *args, **kwargs):
|
def fire(self, *args, **kwargs):
|
||||||
"""Invokes every callable in the observer list, passing in the args and
|
"""Invokes every callable in the observer list, passing in the args and
|
||||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||||
|
@ -101,24 +104,28 @@ class Signal(object):

         Returns a Deferred that will complete when all the observers have
         completed."""
-        with PreserveLoggingContext():
-            deferreds = []
-            for observer in self.observers:
-                d = defer.maybeDeferred(observer, *args, **kwargs)
-
-                def eb(failure):
-                    logger.warning(
-                        "%s signal observer %s failed: %r",
-                        self.name, observer, failure,
-                        exc_info=(
-                            failure.type,
-                            failure.value,
-                            failure.getTracebackObject()))
-                    if not self.suppress_failures:
-                        failure.raiseException()
-                deferreds.append(d.addErrback(eb))
-
-            results = []
-            for deferred in deferreds:
-                result = yield deferred
-                results.append(result)
-        defer.returnValue(results)
+        def do(observer):
+            def eb(failure):
+                logger.warning(
+                    "%s signal observer %s failed: %r",
+                    self.name, observer, failure,
+                    exc_info=(
+                        failure.type,
+                        failure.value,
+                        failure.getTracebackObject()))
+                if not self.suppress_failures:
+                    return failure
+            return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
+
+        with PreserveLoggingContext():
+            deferreds = [
+                do(observer)
+                for observer in self.observers
+            ]
+
+            d = defer.gatherResults(deferreds, consumeErrors=True)
+
+        d.addErrback(unwrapFirstError)
+
+        return preserve_context_over_deferred(d)

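
What the rewritten fire() gives callers, sketched under the assumption that
Signal is constructed as Signal(name, suppress_failures) by
Distributor.declare() and that observers are registered via observe():
observers now run concurrently through defer.gatherResults, and
unwrapFirstError strips the defer.FirstError wrapper so a caller's errback
sees the original failure.

    signal = Signal("user_joined", suppress_failures=False)
    signal.observe(lambda user: defer.succeed("ok: %s" % (user,)))
    signal.observe(lambda user: logger.info("saw %s", user))

    d = signal.fire("@alice:example.com")
    d.addCallback(lambda results: logger.info("observers done: %r", results))
    d.addErrback(lambda f: logger.warning("fire failed: %s", f.value))
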
@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from twisted.internet import defer
+
 import threading
 import logging

@ -129,3 +131,53 @@ class PreserveLoggingContext(object):
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
         LoggingContext.thread_local.current_context = self.current_context
+
+        if self.current_context is not LoggingContext.sentinel:
+            if self.current_context.parent_context is None:
+                logger.warn(
+                    "Restoring dead context: %s",
+                    self.current_context,
+                )
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+    """Takes a function and invokes it with the given arguments, but removes
+    and restores the current logging context while doing so.
+
+    If the result is a deferred, call preserve_context_over_deferred before
+    returning it.
+    """
+    with PreserveLoggingContext():
+        res = fn(*args, **kwargs)
+
+    if isinstance(res, defer.Deferred):
+        return preserve_context_over_deferred(res)
+    else:
+        return res
+
+
+def preserve_context_over_deferred(deferred):
+    """Given a deferred wrap it such that any callbacks added later to it will
+    be invoked with the current context.
+    """
+    d = defer.Deferred()
+
+    current_context = LoggingContext.current_context()
+
+    def cb(res):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.callback(res)
+        return res
+
+    def eb(failure):
+        with PreserveLoggingContext():
+            LoggingContext.thread_local.current_context = current_context
+            res = d.errback(failure)
+        return res
+
+    if deferred.called:
+        return deferred
+
+    deferred.addCallbacks(cb, eb)
+    return d
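
The intended calling pattern for the two helpers above: run a context-naive
function outside the caller's logging context, then resume that context when
its deferred fires. connect_to_server here is hypothetical; anything that
returns a Deferred fits.

    @defer.inlineCallbacks
    def handle_request():
        with LoggingContext("request-42"):
            # The connection code runs (and may complete) outside
            # "request-42"; the yield resumes back inside it.
            result = yield preserve_context_over_fn(
                connect_to_server, "example.com"  # hypothetical helper
            )
            logger.info("got %r", result)
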