0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-17 10:13:50 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into neilj/mau_tracker

This commit is contained in:
Neil Johnson 2018-08-03 13:40:47 +01:00
commit 897c51d274
25 changed files with 450 additions and 379 deletions

View file

@ -27,8 +27,9 @@ Describe here the problem that you are experiencing, or the feature you are requ
Describe how what happens differs from what you expected. Describe how what happens differs from what you expected.
If you can identify any relevant log snippets from _homeserver.log_, please include <!-- If you can identify any relevant log snippets from _homeserver.log_, please include
those here (please be careful to remove any personal or private data): those (please be careful to remove any personal or private data). Please surround them with
``` (three backticks, on a line on their own), so that they are formatted legibly. -->
### Version information ### Version information

View file

@ -1,3 +1,12 @@
Synapse 0.33.1 (2018-08-02)
===========================
SECURITY FIXES
--------------
- Fix a potential issue where servers could request events for rooms they have not joined. ([\#3641](https://github.com/matrix-org/synapse/issues/3641))
- Fix a potential issue where users could see events in private rooms before they joined. ([\#3642](https://github.com/matrix-org/synapse/issues/3642))
Synapse 0.33.0 (2018-07-19) Synapse 0.33.0 (2018-07-19)
=========================== ===========================

View file

@ -51,7 +51,7 @@ makes it horribly hard to review otherwise.
Changelog Changelog
~~~~~~~~~ ~~~~~~~~~
All changes, even minor ones, need a corresponding changelog All changes, even minor ones, need a corresponding changelog / newsfragment
entry. These are managed by Towncrier entry. These are managed by Towncrier
(https://github.com/hawkowl/towncrier). (https://github.com/hawkowl/towncrier).

1
changelog.d/3621.misc Normal file
View file

@ -0,0 +1 @@
Refactor FederationHandler to move DB writes into separate functions

1
changelog.d/3638.misc Normal file
View file

@ -0,0 +1 @@
Factor out exception handling in federation_client

1
changelog.d/3639.feature Normal file
View file

@ -0,0 +1 @@
When we fail to join a room over federation, pass the error code back to the client.

1
changelog.d/3645.misc Normal file
View file

@ -0,0 +1 @@
Update CONTRIBUTING to mention newsfragments.

View file

@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.33.0" __version__ = "0.33.1"

View file

@ -70,20 +70,6 @@ class CodeMessageException(RuntimeError):
self.code = code self.code = code
self.msg = msg self.msg = msg
def error_dict(self):
return cs_error(self.msg)
class MatrixCodeMessageException(CodeMessageException):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
super(MatrixCodeMessageException, self).__init__(code, msg)
self.errcode = errcode
class SynapseError(CodeMessageException): class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error """A base exception type for matrix errors which have an errcode and error
@ -109,38 +95,28 @@ class SynapseError(CodeMessageException):
self.errcode, self.errcode,
) )
@classmethod
def from_http_response_exception(cls, err):
"""Make a SynapseError based on an HTTPResponseException
This is useful when a proxied request has failed, and we need to class ProxiedRequestError(SynapseError):
decide how to map the failure onto a matrix error to send back to the """An error from a general matrix endpoint, eg. from a proxied Matrix API call.
client.
An attempt is made to parse the body of the http response as a matrix Attributes:
error. If that succeeds, the errcode and error message from the body errcode (str): Matrix error code e.g 'M_FORBIDDEN'
are used as the errcode and error message in the new synapse error.
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
set to the reason code from the HTTP response.
Args:
err (HttpResponseException):
Returns:
SynapseError:
""" """
# try to parse the body as json, to get better errcode/msg, but def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
# default to M_UNKNOWN with the HTTP status as the error text super(ProxiedRequestError, self).__init__(
try: code, msg, errcode
j = json.loads(err.response) )
except ValueError: if additional_fields is None:
j = {} self._additional_fields = {}
errcode = j.get('errcode', Codes.UNKNOWN) else:
errmsg = j.get('error', err.msg) self._additional_fields = dict(additional_fields)
res = SynapseError(err.code, errmsg, errcode) def error_dict(self):
return res return cs_error(
self.msg,
self.errcode,
**self._additional_fields
)
class ConsentNotGivenError(SynapseError): class ConsentNotGivenError(SynapseError):
@ -309,14 +285,6 @@ class LimitExceededError(SynapseError):
) )
def cs_exception(exception):
if isinstance(exception, CodeMessageException):
return exception.error_dict()
else:
logger.error("Unknown exception type: %s", type(exception))
return {}
def cs_error(msg, code=Codes.UNKNOWN, **kwargs): def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server """ Utility method for constructing an error response for client-server
interactions. interactions.
@ -373,7 +341,7 @@ class HttpResponseException(CodeMessageException):
Represents an HTTP-level failure of an outbound request Represents an HTTP-level failure of an outbound request
Attributes: Attributes:
response (str): body of response response (bytes): body of response
""" """
def __init__(self, code, msg, response): def __init__(self, code, msg, response):
""" """
@ -381,7 +349,39 @@ class HttpResponseException(CodeMessageException):
Args: Args:
code (int): HTTP status code code (int): HTTP status code
msg (str): reason phrase from HTTP response status line msg (str): reason phrase from HTTP response status line
response (str): body of response response (bytes): body of response
""" """
super(HttpResponseException, self).__init__(code, msg) super(HttpResponseException, self).__init__(code, msg)
self.response = response self.response = response
def to_synapse_error(self):
"""Make a SynapseError based on an HTTPResponseException
This is useful when a proxied request has failed, and we need to
decide how to map the failure onto a matrix error to send back to the
client.
An attempt is made to parse the body of the http response as a matrix
error. If that succeeds, the errcode and error message from the body
are used as the errcode and error message in the new synapse error.
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
set to the reason code from the HTTP response.
Returns:
SynapseError:
"""
# try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text
try:
j = json.loads(self.response)
except ValueError:
j = {}
if not isinstance(j, dict):
j = {}
errcode = j.pop('errcode', Codes.UNKNOWN)
errmsg = j.pop('error', self.msg)
return ProxiedRequestError(self.code, errmsg, errcode, j)

View file

@ -217,6 +217,8 @@ class ServerConfig(Config):
# different cores. See # different cores. See
# https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/. # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
# #
# This setting requires the affinity package to be installed!
#
# cpu_affinity: 0xFFFFFFFF # cpu_affinity: 0xFFFFFFFF
# Whether to serve a web client from the HTTP/HTTPS root resource. # Whether to serve a web client from the HTTP/HTTPS root resource.

View file

@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
PDU_RETRY_TIME_MS = 1 * 60 * 1000 PDU_RETRY_TIME_MS = 1 * 60 * 1000
class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
pass
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs): def __init__(self, hs):
super(FederationClient, self).__init__(hs) super(FederationClient, self).__init__(hs)
@ -458,6 +465,61 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def _try_destination_list(self, description, destinations, callback):
"""Try an operation on a series of servers, until it succeeds
Args:
description (unicode): description of the operation we're doing, for logging
destinations (Iterable[unicode]): list of server_names to try
callback (callable): Function to run for each server. Passed a single
argument: the server_name to try. May return a deferred.
If the callback raises a CodeMessageException with a 300/400 code,
attempts to perform the operation stop immediately and the exception is
reraised.
Otherwise, if the callback raises an Exception the error is logged and the
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
Returns:
The [Deferred] result of callback, if it succeeds
Raises:
SynapseError if the chosen remote server returns a 300/400 code.
RuntimeError if no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue
try:
res = yield callback(destination)
defer.returnValue(res)
except InvalidResponseError as e:
logger.warn(
"Failed to %s via %s: %s",
description, destination, e,
)
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
logger.warn(
"Failed to %s via %s: %i %s",
description, destination, e.code, e.message,
)
except Exception:
logger.warn(
"Failed to %s via %s",
description, destination, exc_info=1,
)
raise RuntimeError("Failed to %s via any server", description)
def make_membership_event(self, destinations, room_id, user_id, membership, def make_membership_event(self, destinations, room_id, user_id, membership,
content={},): content={},):
""" """
@ -481,7 +543,7 @@ class FederationClient(FederationBase):
Deferred: resolves to a tuple of (origin (str), event (object)) Deferred: resolves to a tuple of (origin (str), event (object))
where origin is the remote homeserver which generated the event. where origin is the remote homeserver which generated the event.
Fails with a ``CodeMessageException`` if the chosen remote server Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code. returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
@ -492,11 +554,9 @@ class FederationClient(FederationBase):
"make_membership_event called with membership='%s', must be one of %s" % "make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships)) (membership, ",".join(valid_memberships))
) )
for destination in destinations:
if destination == self.server_name:
continue
try: @defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership destination, room_id, user_id, membership
) )
@ -518,24 +578,11 @@ class FederationClient(FederationBase):
defer.returnValue( defer.returnValue(
(destination, ev) (destination, ev)
) )
break
except CodeMessageException as e: return self._try_destination_list(
if not 500 <= e.code < 600: "make_" + membership, destinations, send_request,
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
) )
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu):
"""Sends a join event to one of a list of homeservers. """Sends a join event to one of a list of homeservers.
@ -552,17 +599,14 @@ class FederationClient(FederationBase):
giving the serer the event was sent to, ``state`` (?) and giving the serer the event was sent to, ``state`` (?) and
``auth_chain``. ``auth_chain``.
Fails with a ``CodeMessageException`` if the chosen remote server Fails with a ``SynapseError`` if the chosen remote server
returns a 300/400 code. returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
""" """
for destination in destinations: @defer.inlineCallbacks
if destination == self.server_name: def send_request(destination):
continue
try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(
destination=destination, destination=destination,
@ -624,31 +668,22 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth, "auth_chain": signed_auth,
"origin": destination, "origin": destination,
}) })
except CodeMessageException as e: return self._try_destination_list("send_join", destinations, send_request)
if not 500 <= e.code < 600:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu): def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try:
code, content = yield self.transport_layer.send_invite( code, content = yield self.transport_layer.send_invite(
destination=destination, destination=destination,
room_id=room_id, room_id=room_id,
event_id=event_id, event_id=event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
except HttpResponseException as e:
if e.code == 403:
raise e.to_synapse_error()
raise
pdu_dict = content["event"] pdu_dict = content["event"]
@ -663,7 +698,6 @@ class FederationClient(FederationBase):
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu): def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers. """Sends a leave event to one of a list of homeservers.
@ -680,16 +714,13 @@ class FederationClient(FederationBase):
Return: Return:
Deferred: resolves to None. Deferred: resolves to None.
Fails with a ``CodeMessageException`` if the chosen remote server Fails with a ``SynapseError`` if the chosen remote server
returns a non-200 code. returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. Fails with a ``RuntimeError`` if no servers were reachable.
""" """
for destination in destinations: @defer.inlineCallbacks
if destination == self.server_name: def send_request(destination):
continue
try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave( _, content = yield self.transport_layer.send_leave(
destination=destination, destination=destination,
@ -700,15 +731,8 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
defer.returnValue(None) defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.") return self._try_destination_list("send_leave", destinations, send_request)
def get_public_rooms(self, destination, limit=None, since_token=None, def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False, search_filter=None, include_all_networks=False,

View file

@ -426,6 +426,7 @@ class FederationServer(FederationBase):
ret = yield self.handler.on_query_auth( ret = yield self.handler.on_query_auth(
origin, origin,
event_id, event_id,
room_id,
signed_auth, signed_auth,
content.get("rejects", []), content.get("rejects", []),
content.get("missing", []), content.get("missing", []),

View file

@ -19,10 +19,12 @@ import random
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import UserID from synapse.types import UserID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -129,11 +131,13 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler): class EventHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, user, event_id): def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event. """Retrieve a single specified event.
Args: Args:
user (synapse.types.UserID): The user requesting the event user (synapse.types.UserID): The user requesting the event
room_id (str|None): The expected room id. We'll return None if the
event's room does not match.
event_id (str): The event ID to obtain. event_id (str): The event ID to obtain.
Returns: Returns:
dict: An event, or None if there is no event matching this ID. dict: An event, or None if there is no event matching this ID.
@ -142,13 +146,26 @@ class EventHandler(BaseHandler):
AuthError if the user does not have the rights to inspect this AuthError if the user does not have the rights to inspect this
event. event.
""" """
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id, check_room_id=room_id)
if not event: if not event:
defer.returnValue(None) defer.returnValue(None)
return return
if hasattr(event, "room_id"): users = yield self.store.get_users_in_room(event.room_id)
yield self.auth.check_joined_room(event.room_id, user.to_string()) is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
self.store,
user.to_string(),
[event],
is_peeking=is_peeking
)
if not filtered:
raise AuthError(
403,
"You don't have permission to access that event."
)
defer.returnValue(event) defer.returnValue(event)

View file

@ -400,7 +400,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
event_stream_id, max_stream_id = yield self._persist_auth_tree( yield self._persist_auth_tree(
origin, auth_chain, state, event origin, auth_chain, state, event
) )
except AuthError as e: except AuthError as e:
@ -444,7 +444,7 @@ class FederationHandler(BaseHandler):
yield self._handle_new_events(origin, event_infos) yield self._handle_new_events(origin, event_infos)
try: try:
context, event_stream_id, max_stream_id = yield self._handle_new_event( context = yield self._handle_new_event(
origin, origin,
event, event,
state=state, state=state,
@ -469,17 +469,6 @@ class FederationHandler(BaseHandler):
except StoreError: except StoreError:
logger.exception("Failed to store room.") logger.exception("Failed to store room.")
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has acutally # Only fire user_joined_room if the user has acutally
@ -501,7 +490,7 @@ class FederationHandler(BaseHandler):
if newly_joined: if newly_joined:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield self.user_joined_room(user, event.room_id)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
@ -942,7 +931,7 @@ class FederationHandler(BaseHandler):
self.room_queues[room_id] = [] self.room_queues[room_id] = []
yield self.store.clean_room_for_join(room_id) yield self._clean_room_for_join(room_id)
handled_events = set() handled_events = set()
@ -981,15 +970,10 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
event_stream_id, max_stream_id = yield self._persist_auth_tree( yield self._persist_auth_tree(
origin, auth_chain, state, event origin, auth_chain, state, event
) )
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
@ -1084,7 +1068,7 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems. # would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin event.internal_metadata.send_on_behalf_of = origin
context, event_stream_id, max_stream_id = yield self._handle_new_event( context = yield self._handle_new_event(
origin, event origin, event
) )
@ -1094,20 +1078,10 @@ class FederationHandler(BaseHandler):
event.signatures, event.signatures,
) )
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield self.user_joined_room(user, event.room_id)
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
@ -1176,17 +1150,7 @@ class FederationHandler(BaseHandler):
) )
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self._persist_events([(event, context)])
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
)
target_user = UserID.from_string(event.state_key)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
defer.returnValue(event) defer.returnValue(event)
@ -1217,17 +1181,7 @@ class FederationHandler(BaseHandler):
) )
context = yield self.state_handler.compute_event_context(event) context = yield self.state_handler.compute_event_context(event)
yield self._persist_events([(event, context)])
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
)
target_user = UserID.from_string(event.state_key)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
defer.returnValue(event) defer.returnValue(event)
@ -1318,7 +1272,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event( yield self._handle_new_event(
origin, event origin, event
) )
@ -1328,22 +1282,17 @@ class FederationHandler(BaseHandler):
event.signatures, event.signatures,
) )
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id): def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event.
""" """
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id,
)
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
room_id, [event_id] room_id, [event_id]
) )
@ -1354,8 +1303,7 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e for e in state (e.type, e.state_key): e for e in state
} }
event = yield self.store.get_event(event_id) if event.is_state():
if event and event.is_state():
# Get previous state # Get previous state
if "replaces_state" in event.unsigned: if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"] prev_id = event.unsigned["replaces_state"]
@ -1374,6 +1322,10 @@ class FederationHandler(BaseHandler):
def get_state_ids_for_pdu(self, room_id, event_id): def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event.
""" """
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id,
)
state_groups = yield self.store.get_state_groups_ids( state_groups = yield self.store.get_state_groups_ids(
room_id, [event_id] room_id, [event_id]
) )
@ -1382,8 +1334,7 @@ class FederationHandler(BaseHandler):
_, state = state_groups.items().pop() _, state = state_groups.items().pop()
results = state results = state
event = yield self.store.get_event(event_id) if event.is_state():
if event and event.is_state():
# Get previous state # Get previous state
if "replaces_state" in event.unsigned: if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"] prev_id = event.unsigned["replaces_state"]
@ -1472,9 +1423,8 @@ class FederationHandler(BaseHandler):
event, context event, context
) )
event_stream_id, max_stream_id = yield self.store.persist_event( yield self._persist_events(
event, [(event, context)],
context=context,
backfilled=backfilled, backfilled=backfilled,
) )
except: # noqa: E722, as we reraise the exception this is fine. except: # noqa: E722, as we reraise the exception this is fine.
@ -1487,15 +1437,7 @@ class FederationHandler(BaseHandler):
six.reraise(tp, value, tb) six.reraise(tp, value, tb)
if not backfilled: defer.returnValue(context)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
logcontext.run_in_background(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False): def _handle_new_events(self, origin, event_infos, backfilled=False):
@ -1503,6 +1445,8 @@ class FederationHandler(BaseHandler):
should not depend on one another, e.g. this should be used to persist should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations. on each other for state calculations.
Notifies about the events where appropriate.
""" """
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
@ -1517,7 +1461,7 @@ class FederationHandler(BaseHandler):
], consumeErrors=True, ], consumeErrors=True,
)) ))
yield self.store.persist_events( yield self._persist_events(
[ [
(ev_info["event"], context) (ev_info["event"], context)
for ev_info, context in zip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
@ -1529,7 +1473,8 @@ class FederationHandler(BaseHandler):
def _persist_auth_tree(self, origin, auth_events, state, event): def _persist_auth_tree(self, origin, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the """Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically. state and event. Then persists the auth chain and state atomically.
Persists the event seperately. Persists the event separately. Notifies about the persisted events
where appropriate.
Will attempt to fetch missing auth events. Will attempt to fetch missing auth events.
@ -1540,8 +1485,7 @@ class FederationHandler(BaseHandler):
event (Event) event (Event)
Returns: Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event Deferred
call for `event`
""" """
events_to_context = {} events_to_context = {}
for e in itertools.chain(auth_events, state): for e in itertools.chain(auth_events, state):
@ -1605,7 +1549,7 @@ class FederationHandler(BaseHandler):
raise raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
yield self.store.persist_events( yield self._persist_events(
[ [
(e, events_to_context[e.event_id]) (e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state) for e in itertools.chain(auth_events, state)
@ -1616,12 +1560,10 @@ class FederationHandler(BaseHandler):
event, old_state=state event, old_state=state
) )
event_stream_id, max_stream_id = yield self.store.persist_event( yield self._persist_events(
event, new_event_context, [(event, new_event_context)],
) )
defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None): def _prep_event(self, origin, event, state=None, auth_events=None):
""" """
@ -1678,8 +1620,19 @@ class FederationHandler(BaseHandler):
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
missing): missing):
in_room = yield self.auth.check_host_in_room(
room_id,
origin
)
if not in_room:
raise AuthError(403, "Host not in room.")
event = yield self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
# Just go through and process each event in `remote_auth_chain`. We # Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong. # don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain: for e in remote_auth_chain:
@ -1689,7 +1642,6 @@ class FederationHandler(BaseHandler):
pass pass
# Now get the current auth_chain for the event. # Now get the current auth_chain for the event.
event = yield self.store.get_event(event_id)
local_auth_chain = yield self.store.get_auth_chain( local_auth_chain = yield self.store.get_auth_chain(
[auth_id for auth_id, _ in event.auth_events], [auth_id for auth_id, _ in event.auth_events],
include_given=True include_given=True
@ -2347,3 +2299,69 @@ class FederationHandler(BaseHandler):
) )
if "valid" not in response or not response["valid"]: if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid") raise AuthError(403, "Third party certificate was invalid")
@defer.inlineCallbacks
def _persist_events(self, event_and_contexts, backfilled=False):
"""Persists events and tells the notifier/pushers about them, if
necessary.
Args:
event_and_contexts(list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether these events are a result of
backfilling or not
Returns:
Deferred
"""
max_stream_id = yield self.store.persist_events(
event_and_contexts,
backfilled=backfilled,
)
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
self._notify_persisted_event(event, max_stream_id)
def _notify_persisted_event(self, event, max_stream_id):
"""Checks to see if notifier/pushers should be notified about the
event or not.
Args:
event (FrozenEvent)
max_stream_id (int): The max_stream_id returned by persist_events
"""
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
# We notify for memberships if its an invite for one of our
# users
if event.internal_metadata.is_outlier():
if event.membership != Membership.INVITE:
if not self.is_mine_id(target_user_id):
return
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
elif event.internal_metadata.is_outlier():
return
event_stream_id = event.internal_metadata.stream_ordering
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
logcontext.run_in_background(
self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id,
)
def _clean_room_for_join(self, room_id):
return self.store.clean_room_for_join(room_id)
def user_joined_room(self, user, room_id):
"""Called when a new user has joined the room
"""
return user_joined_room(self.distributor, user, room_id)

View file

@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
Codes, Codes,
MatrixCodeMessageException, HttpResponseException,
SynapseError, SynapseError,
) )
@ -85,7 +85,6 @@ class IdentityHandler(BaseHandler):
) )
defer.returnValue(None) defer.returnValue(None)
data = {}
try: try:
data = yield self.http_client.get_json( data = yield self.http_client.get_json(
"https://%s%s" % ( "https://%s%s" % (
@ -94,11 +93,9 @@ class IdentityHandler(BaseHandler):
), ),
{'sid': creds['sid'], 'client_secret': client_secret} {'sid': creds['sid'], 'client_secret': client_secret}
) )
except MatrixCodeMessageException as e: except HttpResponseException as e:
logger.info("getValidated3pid failed with Matrix error: %r", e) logger.info("getValidated3pid failed with Matrix error: %r", e)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data: if 'medium' in data:
defer.returnValue(data) defer.returnValue(data)
@ -136,7 +133,7 @@ class IdentityHandler(BaseHandler):
) )
logger.debug("bound threepid %r to %s", creds, mxid) logger.debug("bound threepid %r to %s", creds, mxid)
except CodeMessageException as e: except CodeMessageException as e:
data = json.loads(e.msg) data = json.loads(e.msg) # XXX WAT?
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -209,12 +206,9 @@ class IdentityHandler(BaseHandler):
params params
) )
defer.returnValue(data) defer.returnValue(data)
except MatrixCodeMessageException as e: except HttpResponseException as e:
logger.info("Proxied requestToken failed with Matrix error: %r", e)
raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e) logger.info("Proxied requestToken failed: %r", e)
raise e raise e.to_synapse_error()
@defer.inlineCallbacks @defer.inlineCallbacks
def requestMsisdnToken( def requestMsisdnToken(
@ -244,9 +238,6 @@ class IdentityHandler(BaseHandler):
params params
) )
defer.returnValue(data) defer.returnValue(data)
except MatrixCodeMessageException as e: except HttpResponseException as e:
logger.info("Proxied requestToken failed with Matrix error: %r", e)
raise SynapseError(e.code, e.msg, e.errcode)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e) logger.info("Proxied requestToken failed: %r", e)
raise e raise e.to_synapse_error()

View file

@ -39,12 +39,7 @@ from twisted.web.client import (
from twisted.web.http import PotentialDataLoss from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from synapse.api.errors import ( from synapse.api.errors import Codes, HttpResponseException, SynapseError
CodeMessageException,
Codes,
MatrixCodeMessageException,
SynapseError,
)
from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http import cancelled_to_request_timed_out_error, redact_uri
from synapse.http.endpoint import SpiderEndpoint from synapse.http.endpoint import SpiderEndpoint
from synapse.util.async import add_timeout_to_deferred from synapse.util.async import add_timeout_to_deferred
@ -132,6 +127,11 @@ class SimpleHttpClient(object):
Returns: Returns:
Deferred[object]: parsed json Deferred[object]: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
""" """
# TODO: Do we ever want to log message contents? # TODO: Do we ever want to log message contents?
@ -155,7 +155,10 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response)) body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else:
raise HttpResponseException(response.code, response.phrase, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json_get_json(self, uri, post_json, headers=None): def post_json_get_json(self, uri, post_json, headers=None):
@ -169,6 +172,11 @@ class SimpleHttpClient(object):
Returns: Returns:
Deferred[object]: parsed json Deferred[object]: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
""" """
json_str = encode_canonical_json(post_json) json_str = encode_canonical_json(post_json)
@ -193,9 +201,7 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else: else:
raise self._exceptionFromFailedRequest(response, body) raise HttpResponseException(response.code, response.phrase, body)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}, headers=None): def get_json(self, uri, args={}, headers=None):
@ -213,14 +219,12 @@ class SimpleHttpClient(object):
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON. HTTP body as JSON.
Raises: Raises:
On a non-2xx HTTP response. The response body will be used as the HttpResponseException On a non-2xx HTTP response.
error message.
ValueError: if the response was not JSON
""" """
try:
body = yield self.get_raw(uri, args, headers=headers) body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None): def put_json(self, uri, json_body, args={}, headers=None):
@ -239,7 +243,9 @@ class SimpleHttpClient(object):
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON. HTTP body as JSON.
Raises: Raises:
On a non-2xx HTTP response. HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
""" """
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
@ -266,10 +272,7 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else: else:
# NB: This is explicitly not json.loads(body)'d because the contract raise HttpResponseException(response.code, response.phrase, body)
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_raw(self, uri, args={}, headers=None): def get_raw(self, uri, args={}, headers=None):
@ -287,8 +290,7 @@ class SimpleHttpClient(object):
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text. HTTP body at text.
Raises: Raises:
On a non-2xx HTTP response. The response body will be used as the HttpResponseException on a non-2xx HTTP response.
error message.
""" """
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
@ -311,16 +313,7 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(body) defer.returnValue(body)
else: else:
raise CodeMessageException(response.code, body) raise HttpResponseException(response.code, response.phrase, body)
def _exceptionFromFailedRequest(self, response, body):
try:
jsonBody = json.loads(body)
errcode = jsonBody['errcode']
error = jsonBody['error']
return MatrixCodeMessageException(response.code, error, errcode)
except (ValueError, KeyError):
return CodeMessageException(response.code, body)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out. # The two should be factored out.

View file

@ -36,7 +36,6 @@ from synapse.api.errors import (
Codes, Codes,
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
cs_exception,
) )
from synapse.http.request_metrics import requests_counter from synapse.http.request_metrics import requests_counter
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
@ -77,16 +76,13 @@ def wrap_json_request_handler(h):
def wrapped_request_handler(self, request): def wrapped_request_handler(self, request):
try: try:
yield h(self, request) yield h(self, request)
except CodeMessageException as e: except SynapseError as e:
code = e.code code = e.code
if isinstance(e, SynapseError):
logger.info( logger.info(
"%s SynapseError: %s - %s", request, code, e.msg "%s SynapseError: %s - %s", request, code, e.msg
) )
else:
logger.exception(e)
respond_with_json( respond_with_json(
request, code, cs_exception(e), send_cors=True, request, code, e.error_dict(), send_cors=True,
pretty_print=_request_user_agent_is_curl(request), pretty_print=_request_user_agent_is_curl(request),
) )

View file

@ -18,7 +18,7 @@ import re
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import MatrixCodeMessageException, SynapseError from synapse.api.errors import HttpResponseException
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import Requester, UserID from synapse.types import Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room from synapse.util.distributor import user_joined_room, user_left_room
@ -56,11 +56,11 @@ def remote_join(client, host, port, requester, remote_room_hosts,
try: try:
result = yield client.post_json_get_json(uri, payload) result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e: except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And
# importantly, not stack traces everywhere) # importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
defer.returnValue(result) defer.returnValue(result)
@ -92,11 +92,11 @@ def remote_reject_invite(client, host, port, requester, remote_room_hosts,
try: try:
result = yield client.post_json_get_json(uri, payload) result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e: except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And
# importantly, not stack traces everywhere) # importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
defer.returnValue(result) defer.returnValue(result)
@ -131,11 +131,11 @@ def get_or_register_3pid_guest(client, host, port, requester,
try: try:
result = yield client.post_json_get_json(uri, payload) result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e: except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And
# importantly, not stack traces everywhere) # importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
defer.returnValue(result) defer.returnValue(result)
@ -165,11 +165,11 @@ def notify_user_membership_change(client, host, port, user_id, room_id, change):
try: try:
result = yield client.post_json_get_json(uri, payload) result = yield client.post_json_get_json(uri, payload)
except MatrixCodeMessageException as e: except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And
# importantly, not stack traces everywhere) # importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
defer.returnValue(result) defer.returnValue(result)

View file

@ -18,11 +18,7 @@ import re
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import CodeMessageException, HttpResponseException
CodeMessageException,
MatrixCodeMessageException,
SynapseError,
)
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -83,11 +79,11 @@ def send_event_to_master(clock, store, client, host, port, requester, event, con
# If we timed out we probably don't need to worry about backing # If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway. # off too much, but lets just wait a little anyway.
yield clock.sleep(1) yield clock.sleep(1)
except MatrixCodeMessageException as e: except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And
# importantly, not stack traces everywhere) # importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode) raise e.to_synapse_error()
defer.returnValue(result) defer.returnValue(result)

View file

@ -88,7 +88,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
event = yield self.event_handler.get_event(requester.user, event_id) event = yield self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View file

@ -506,7 +506,7 @@ class RoomEventServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
event = yield self.event_handler.get_event(requester.user, event_id) event = yield self.event_handler.get_event(requester.user, room_id, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View file

@ -379,7 +379,7 @@ class MediaRepository(object):
logger.warn("HTTP error fetching remote media %s/%s: %s", logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response) server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND: if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e) raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
except SynapseError: except SynapseError:

View file

@ -343,6 +343,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
table="events", table="events",
keyvalues={ keyvalues={
"event_id": event_id, "event_id": event_id,
"room_id": room_id,
}, },
retcol="depth", retcol="depth",
allow_none=True, allow_none=True,

View file

@ -241,12 +241,18 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False): def persist_events(self, events_and_contexts, backfilled=False):
""" """
Write events to the database Write events to the database
Args: Args:
events_and_contexts: list of tuples of (event, context) events_and_contexts: list of tuples of (event, context)
backfilled: ? backfilled (bool): Whether the results are retrieved from federation
via backfill or not. Used to determine if they're "new" events
which might update the current state etc.
Returns:
Deferred[int]: the stream ordering of the latest persisted event
""" """
partitioned = {} partitioned = {}
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
@ -263,10 +269,14 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
for room_id in partitioned: for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
return make_deferred_yieldable( yield make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue(max_persisted_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, backfilled=False): def persist_event(self, event, context, backfilled=False):

View file

@ -19,7 +19,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import NotFoundError
# these are only included to make the type annotations work # these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -77,7 +77,7 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False, get_prev_content=False, allow_rejected=False,
allow_none=False): allow_none=False, check_room_id=None):
"""Get an event from the database by event_id. """Get an event from the database by event_id.
Args: Args:
@ -88,7 +88,9 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field. include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events. allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if allow_none (bool): If True, return None if no event found, if
False throw an exception. False throw a NotFoundError
check_room_id (str|None): if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns: Returns:
Deferred : A FrozenEvent. Deferred : A FrozenEvent.
@ -100,10 +102,16 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected=allow_rejected, allow_rejected=allow_rejected,
) )
if not events and not allow_none: event = events[0] if events else None
raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None) if event is not None and check_room_id is not None:
if event.room_id != check_room_id:
event = None
if event is None and not allow_none:
raise NotFoundError("Could not find event %s" % (event_id,))
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True, def get_events(self, event_ids, check_redacted=True,