diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 595b1760f..dccc579ea 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -46,7 +46,7 @@ class SpamChecker(object):
 
         return self.spam_checker.check_event_for_spam(event)
 
-    def user_may_invite(self, userid, room_id):
+    def user_may_invite(self, inviter_userid, invitee_userid, room_id):
        """Checks if a given user may send an invite
 
         If this method returns false, the invite will be rejected.
@@ -60,7 +60,7 @@
         if self.spam_checker is None:
             return True
 
-        return self.spam_checker.user_may_invite(userid, room_id)
+        return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
 
     def user_may_create_room(self, userid):
         """Checks if a given user may create a room
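A module implementing the widened hook might look like the sketch below. The module and config keys are illustrative, not part of this diff; only the three-argument user_may_invite signature comes from the change above:

    # example_spam_checker.py -- a hypothetical spam-checker module
    class ExampleSpamChecker(object):
        def __init__(self, config):
            self.blocked_servers = config.get("blocked_servers", [])

        def user_may_invite(self, inviter_userid, invitee_userid, room_id):
            # returning False causes synapse to reject the invite
            inviter_domain = inviter_userid.split(":", 1)[1]
            return inviter_domain not in self.blocked_servers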
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 51e3fdea0..f00d59e70 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -12,14 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from twisted.internet import defer
 
 from .federation_base import FederationBase
 from .units import Transaction, Edu
 
-from synapse.util.async import Linearizer
+from synapse.util import async
 from synapse.util.logutils import log_function
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
@@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
 import simplejson as json
 import logging
 
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
 
 logger = logging.getLogger(__name__)
@@ -52,7 +53,8 @@
 
         self.auth = hs.get_auth()
 
-        self._server_linearizer = Linearizer("fed_server")
+        self._server_linearizer = async.Linearizer("fed_server")
+        self._transaction_linearizer = async.Linearizer("fed_txn_handler")
 
         # We cache responses to state queries, as they take a while and often
         # come in waves.
@@ -109,25 +111,41 @@
     @defer.inlineCallbacks
     @log_function
     def on_incoming_transaction(self, transaction_data):
+        # keep this as early as possible to make the calculated origin ts as
+        # accurate as possible.
+        request_time = self._clock.time_msec()
+
         transaction = Transaction(**transaction_data)
 
-        received_pdus_counter.inc_by(len(transaction.pdus))
-
-        for p in transaction.pdus:
-            if "unsigned" in p:
-                unsigned = p["unsigned"]
-                if "age" in unsigned:
-                    p["age"] = unsigned["age"]
-            if "age" in p:
-                p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
-                del p["age"]
-
-        pdu_list = [
-            self.event_from_pdu_json(p) for p in transaction.pdus
-        ]
+        if not transaction.transaction_id:
+            raise Exception("Transaction missing transaction_id")
+        if not transaction.origin:
+            raise Exception("Transaction missing origin")
 
         logger.debug("[%s] Got transaction", transaction.transaction_id)
 
+        # use a linearizer to ensure that we don't process the same transaction
+        # multiple times in parallel.
+        with (yield self._transaction_linearizer.queue(
+                (transaction.origin, transaction.transaction_id),
+        )):
+            result = yield self._handle_incoming_transaction(
+                transaction, request_time,
+            )
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _handle_incoming_transaction(self, transaction, request_time):
+        """ Process an incoming transaction and return the HTTP response
+
+        Args:
+            transaction (Transaction): incoming transaction
+            request_time (int): timestamp that the HTTP request arrived at
+
+        Returns:
+            Deferred[(int, object)]: http response code and body
+        """
         response = yield self.transaction_actions.have_responded(transaction)
 
         if response:
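The Linearizer above serialises work per key: a second transaction with the same (origin, transaction_id) waits for the first to finish rather than being processed concurrently. A minimal standalone sketch of the pattern, assuming the synapse.util.async.Linearizer API of this era (do_work is a hypothetical helper):

    from twisted.internet import defer
    from synapse.util.async import Linearizer

    linearizer = Linearizer("example")

    @defer.inlineCallbacks
    def handle(origin, txn_id):
        # queue() returns a Deferred which resolves to a context manager;
        # entering it blocks until earlier holders of the same key exit,
        # so duplicate transactions run one at a time, not in parallel.
        with (yield linearizer.queue((origin, txn_id))):
            result = yield do_work(origin, txn_id)  # hypothetical helper
        defer.returnValue(result)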
@@ -140,42 +158,50 @@
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        results = []
+        received_pdus_counter.inc_by(len(transaction.pdus))
 
-        for pdu in pdu_list:
-            # check that it's actually being sent from a valid destination to
-            # workaround bug #1753 in 0.18.5 and 0.18.6
-            if transaction.origin != get_domain_from_id(pdu.event_id):
-                # We continue to accept join events from any server; this is
-                # necessary for the federation join dance to work correctly.
-                # (When we join over federation, the "helper" server is
-                # responsible for sending out the join event, rather than the
-                # origin. See bug #1893).
-                if not (
-                    pdu.type == 'm.room.member' and
-                    pdu.content and
-                    pdu.content.get("membership", None) == 'join'
-                ):
-                    logger.info(
-                        "Discarding PDU %s from invalid origin %s",
-                        pdu.event_id, transaction.origin
-                    )
-                    continue
-                else:
-                    logger.info(
-                        "Accepting join PDU %s from %s",
-                        pdu.event_id, transaction.origin
-                    )
+        pdus_by_room = {}
 
-            try:
-                yield self._handle_received_pdu(transaction.origin, pdu)
-                results.append({})
-            except FederationError as e:
-                self.send_failure(e, transaction.origin)
-                results.append({"error": str(e)})
-            except Exception as e:
-                results.append({"error": str(e)})
-                logger.exception("Failed to handle PDU")
+        for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
+            if "age" in p:
+                p["age_ts"] = request_time - int(p["age"])
+                del p["age"]
+
+            event = self.event_from_pdu_json(p)
+            room_id = event.room_id
+            pdus_by_room.setdefault(room_id, []).append(event)
+
+        pdu_results = {}
+
+        # we can process different rooms in parallel (which is useful if they
+        # require callouts to other servers to fetch missing events), but
+        # impose a limit to avoid going too crazy with ram/cpu.
+        @defer.inlineCallbacks
+        def process_pdus_for_room(room_id):
+            logger.debug("Processing PDUs for %s", room_id)
+            for pdu in pdus_by_room[room_id]:
+                event_id = pdu.event_id
+                try:
+                    yield self._handle_received_pdu(
+                        transaction.origin, pdu
+                    )
+                    pdu_results[event_id] = {}
+                except FederationError as e:
+                    logger.warn("Error handling PDU %s: %s", event_id, e)
+                    self.send_failure(e, transaction.origin)
+                    pdu_results[event_id] = {"error": str(e)}
+                except Exception as e:
+                    pdu_results[event_id] = {"error": str(e)}
+                    logger.exception("Failed to handle PDU %s", event_id)
+
+        yield async.concurrently_execute(
+            process_pdus_for_room, pdus_by_room.keys(),
+            TRANSACTION_CONCURRENCY_LIMIT,
+        )
 
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
@@ -188,14 +214,12 @@
         for failure in getattr(transaction, "pdu_failures", []):
             logger.info("Got failure %r", failure)
 
-        logger.debug("Returning: %s", str(results))
-
         response = {
-            "pdus": dict(zip(
-                (p.event_id for p in pdu_list), results
-            )),
+            "pdus": pdu_results,
         }
 
+        logger.debug("Returning: %s", str(response))
+
         yield self.transaction_actions.set_response(
             transaction,
             200, response
@@ -520,6 +544,30 @@
         Returns (Deferred): completes with None
         Raises: FederationError if the signatures / hash do not match
         """
+        # check that it's actually being sent from a valid destination to
+        # workaround bug #1753 in 0.18.5 and 0.18.6
+        if origin != get_domain_from_id(pdu.event_id):
+            # We continue to accept join events from any server; this is
+            # necessary for the federation join dance to work correctly.
+            # (When we join over federation, the "helper" server is
+            # responsible for sending out the join event, rather than the
+            # origin. See bug #1893).
+            if not (
+                pdu.type == 'm.room.member' and
+                pdu.content and
+                pdu.content.get("membership", None) == 'join'
+            ):
+                logger.info(
+                    "Discarding PDU %s from invalid origin %s",
+                    pdu.event_id, origin
+                )
+                return
+            else:
+                logger.info(
+                    "Accepting join PDU %s from %s",
+                    pdu.event_id, origin
+                )
+
         # Check signature.
         try:
             pdu = yield self._check_sigs_and_hash(pdu)
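Taken together, the new code groups PDUs by room and fans out across rooms while keeping per-room ordering. A condensed sketch of that shape, with handle_pdu standing in for _handle_received_pdu:

    from twisted.internet import defer
    from synapse.util import async

    TRANSACTION_CONCURRENCY_LIMIT = 10

    @defer.inlineCallbacks
    def handle_pdus(pdus_by_room, handle_pdu):
        pdu_results = {}

        @defer.inlineCallbacks
        def process_pdus_for_room(room_id):
            # PDUs within a room are handled strictly in order...
            for pdu in pdus_by_room[room_id]:
                yield handle_pdu(pdu)
                pdu_results[pdu.event_id] = {}

        # ...but up to TRANSACTION_CONCURRENCY_LIMIT rooms run in parallel
        yield async.concurrently_execute(
            process_pdus_for_room, pdus_by_room.keys(),
            TRANSACTION_CONCURRENCY_LIMIT,
        )
        defer.returnValue(pdu_results)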
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 003eaba89..7a3c9cbb7 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -20,8 +20,8 @@ from .persistence import TransactionActions
 from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
+from synapse.util import logcontext
 from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
 from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
@@ -231,11 +231,9 @@
             (pdu, order)
         )
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
-    @preserve_fn  # the caller should not yield on this
+    @logcontext.preserve_fn  # the caller should not yield on this
     @defer.inlineCallbacks
     def send_presence(self, states):
         """Send the new presence states to the appropriate destinations.
@@ -299,7 +297,7 @@
                 state.user_id: state for state in states
             })
 
-            preserve_fn(self._attempt_new_transaction)(destination)
+            self._attempt_new_transaction(destination)
 
     def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
@@ -321,9 +319,7 @@
         else:
             self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
@@ -336,9 +332,7 @@
             destination, []
         ).append(failure)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
@@ -347,15 +341,24 @@
         if not self.can_send_to(destination):
             return
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def get_current_token(self):
         return 0
 
-    @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
+        """Try to start a new transaction to this destination
+
+        If there is already a transaction in progress to this destination,
+        returns immediately. Otherwise kicks off the process of sending a
+        transaction in the background.
+
+        Args:
+            destination (str):
+
+        Returns:
+            None
+        """
         # list of (pending_pdu, deferred, order)
         if destination in self.pending_transactions:
             # XXX: pending_transactions can get stuck on by a never-ending
@@ -368,6 +371,19 @@
             )
             return
 
+        logger.debug("TX [%s] Starting transaction loop", destination)
+
+        # Drop the logcontext before starting the transaction. It doesn't
+        # really make sense to log all the outbound transactions against
+        # whatever path led us to this point: that's pretty arbitrary really.
+        #
+        # (this also means we can fire off _perform_transaction without
+        # yielding)
+        with logcontext.PreserveLoggingContext():
+            self._transaction_transmission_loop(destination)
+
+    @defer.inlineCallbacks
+    def _transaction_transmission_loop(self, destination):
         pending_pdus = []
         try:
             self.pending_transactions[destination] = 1
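The fire-and-forget pattern being introduced here is worth spelling out: reset to the sentinel logcontext with PreserveLoggingContext, then call the inlineCallbacks loop without yielding on it, so the background work is not logged against (and does not keep alive) whatever request happened to trigger it. A sketch under those assumptions (send_next_transaction is a hypothetical helper):

    from twisted.internet import defer
    from synapse.util import logcontext

    def kick_off_send(destination):
        # returns None; callers must not wait for the background loop
        with logcontext.PreserveLoggingContext():
            transmission_loop(destination)

    @defer.inlineCallbacks
    def transmission_loop(destination):
        # placeholder body: the real loop drains the destination's queue
        yield send_next_transaction(destination)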
"""Contains handlers for federation events.""" -import synapse.util.logcontext from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -26,10 +25,7 @@ from synapse.api.errors import ( ) from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator -from synapse.util import unwrapFirstError -from synapse.util.logcontext import ( - preserve_fn, preserve_context_over_deferred -) +from synapse.util import unwrapFirstError, logcontext from synapse.util.metrics import measure_func from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor, Linearizer @@ -592,9 +588,9 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch ) - results = yield preserve_context_over_deferred(defer.gatherResults( + results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.replication_layer.get_pdu)( + logcontext.preserve_fn(self.replication_layer.get_pdu)( [dest], event_id, outlier=True, @@ -786,10 +782,14 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) logger.debug("calling resolve_state_groups in _maybe_backfill") - states = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) - for e in event_ids - ])) + states = yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.preserve_fn(self.state_handler.resolve_state_groups)( + room_id, [e] + ) + for e in event_ids + ], consumeErrors=True, + )) states = dict(zip(event_ids, [s.state for s in states])) state_map = yield self.store.get_events( @@ -942,9 +942,7 @@ class FederationHandler(BaseHandler): # lots of requests for missing prev_events which we do actually # have. Hence we fire off the deferred, but don't wait for it. - synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( - room_queue - ) + logcontext.preserve_fn(self._handle_queued_pdus)(room_queue) defer.returnValue(True) @@ -1071,6 +1069,9 @@ class FederationHandler(BaseHandler): """ event = pdu + if event.state_key is None: + raise SynapseError(400, "The invite event did not have a state key") + is_blocked = yield self.store.is_room_blocked(event.room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -1078,9 +1079,11 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") - if not self.spam_checker.user_may_invite(event.sender, event.room_id): + if not self.spam_checker.user_may_invite( + event.sender, event.state_key, event.room_id, + ): raise SynapseError( - 403, "This user is not permitted to send invites to this server" + 403, "This user is not permitted to send invites to this server/user" ) membership = event.content.get("membership") @@ -1091,9 +1094,6 @@ class FederationHandler(BaseHandler): if sender_domain != origin: raise SynapseError(400, "The invite event was not from the server sending it") - if event.state_key is None: - raise SynapseError(400, "The invite event did not have a state key") - if not self.is_mine_id(event.state_key): raise SynapseError(400, "The invite event must be for this server") @@ -1436,7 +1436,7 @@ class FederationHandler(BaseHandler): if not backfilled: # this intentionally does not yield: we don't care about the result # and don't need to wait for it. 
@@ -1436,7 +1436,7 @@
         if not backfilled:
             # this intentionally does not yield: we don't care about the result
             # and don't need to wait for it.
-            preserve_fn(self.pusher_pool.on_new_notifications)(
+            logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
                 event_stream_id, max_stream_id
             )
@@ -1449,16 +1449,16 @@
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
         """
-        contexts = yield preserve_context_over_deferred(defer.gatherResults(
+        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self._prep_event)(
+                logcontext.preserve_fn(self._prep_event)(
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
                    auth_events=ev_info.get("auth_events"),
                 )
                 for ev_info in event_infos
-            ]
+            ], consumeErrors=True,
         ))
 
         yield self.store.persist_events(
@@ -1766,18 +1766,17 @@
 
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)
 
-            different_events = yield preserve_context_over_deferred(defer.gatherResults(
-                [
-                    preserve_fn(self.store.get_event)(
+            different_events = yield logcontext.make_deferred_yieldable(
+                defer.gatherResults([
+                    logcontext.preserve_fn(self.store.get_event)(
                         d,
                         allow_none=True,
                         allow_rejected=False,
                     )
                     for d in different_auth
                     if d in have_events and not have_events[d]
-                ],
-                consumeErrors=True
-            )).addErrback(unwrapFirstError)
+                ], consumeErrors=True)
+            ).addErrback(unwrapFirstError)
 
             if different_events:
                 local_view = dict(auth_events)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 37985fa1f..36a8ef8ce 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -225,7 +225,7 @@
                 block_invite = True
 
             if not self.spam_checker.user_may_invite(
-                requester.user.to_string(), room_id,
+                requester.user.to_string(), target.to_string(), room_id,
             ):
                 logger.info("Blocking invite due to spam checker")
                 block_invite = True
diff --git a/synapse/state.py b/synapse/state.py
index 390799fbd..dcdcdef65 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -288,6 +288,9 @@
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
+        # map from state group id to the state in that state group (where
+        # 'state' is a map from state key to event id)
+        # dict[int, dict[(str, str), str]]
         state_groups_ids = yield self.store.get_state_groups_ids(
             room_id, event_ids
         )
@@ -320,11 +323,15 @@
             "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
         )
 
+        # build a map from state key to the event_ids which set that state.
+        # dict[(str, str), set[str]]
         state = {}
         for st in state_groups_ids.values():
             for key, e_id in st.items():
                 state.setdefault(key, set()).add(e_id)
 
+        # build a map from state key to the event_ids which set that state,
+        # keeping only the state keys which are in conflict (i.e. set by
+        # more than one event).
         conflicted_state = {
             k: list(v)
             for k, v in state.items()
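A toy illustration of the maps documented above, with invented event ids (two state groups disagreeing about the room name):

    state_groups_ids = {  # dict[int, dict[(str, str), str]]
        1: {("m.room.name", ""): "$a", ("m.room.topic", ""): "$t"},
        2: {("m.room.name", ""): "$b", ("m.room.topic", ""): "$t"},
    }

    # dict[(str, str), set[str]]: state key -> event_ids which set it
    state = {}
    for st in state_groups_ids.values():
        for key, e_id in st.items():
            state.setdefault(key, set()).add(e_id)

    # keep only the keys the groups disagree on
    conflicted_state = {
        k: list(v) for k, v in state.items() if len(v) > 1
    }
    # => {("m.room.name", ""): ["$a", "$b"]}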
@@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
 
     logger.info("Asking for %d conflicted events", len(needed_events))
 
+    # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+    # the state events which are in conflict.
     state_map = yield state_map_factory(needed_events)
 
+    # get the ids of the auth events which allow us to authenticate the
+    # conflicted state, picking only from the unconflicted state.
+    #
+    # dict[(str, str), str]: a map from state key to event id
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1453faf0e..bb252f75d 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
 from .logcontext import (
     PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
 )
-from synapse.util import unwrapFirstError
+from synapse.util import logcontext, unwrapFirstError
 
 from contextlib import contextmanager
@@ -155,7 +155,7 @@ def concurrently_execute(func, args, limit):
         except StopIteration:
             pass
 
-    return preserve_context_over_deferred(defer.gatherResults([
+    return logcontext.make_deferred_yieldable(defer.gatherResults([
         preserve_fn(_concurrently_execute_inner)()
         for _ in xrange(limit)
     ], consumeErrors=True)).addErrback(unwrapFirstError)
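For reference, the function touched here spawns `limit` workers which pull from one shared iterator. A simplified standalone version of that pattern, dropping the logcontext handling shown in the diff:

    from twisted.internet import defer

    def concurrently_execute_sketch(func, args, limit):
        it = iter(args)

        @defer.inlineCallbacks
        def worker():
            # each worker keeps taking the next unclaimed item until the
            # shared iterator is exhausted, so at most 'limit' calls to
            # func are in flight at once
            try:
                while True:
                    yield func(next(it))
            except StopIteration:
                pass

        return defer.gatherResults(
            [worker() for _ in range(limit)], consumeErrors=True,
        )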