From 5d89a526f1bd217148b3f82efd6ec156a78af894 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 12 Mar 2019 21:57:47 +0000 Subject: [PATCH] Factor per-destination stuff out of TransactionQueue This is easier than having to have a million fields keyed on destination. --- synapse/federation/transaction_queue.py | 314 ++++++++++++++---------- 1 file changed, 182 insertions(+), 132 deletions(-) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 288cb5045c..c1f6985ae4 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -27,6 +27,7 @@ from synapse.api.errors import ( HttpResponseException, RequestSendFailed, ) +from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.metrics import ( LaterGauge, @@ -36,6 +37,7 @@ from synapse.metrics import ( sent_transactions_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage import UserPresenceState from synapse.util import logcontext from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter @@ -80,73 +82,47 @@ class TransactionQueue(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.transaction_actions = TransactionActions(self.store) - - self.transport_layer = hs.get_federation_transport_client() self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id - # Is a mapping from destinations -> deferreds. 
Used to keep track - # of which destinations have transactions in flight and when they are - # done - self.pending_transactions = {} + self._transaction_sender = TransactionSender(hs) + + # map from destination to PerDestinationQueue + self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", "", [], - lambda: len(self.pending_transactions), + lambda: sum( + 1 for d in self._per_destination_queues.values() + if d.transmission_loop_running + ), ) - # Is a mapping from destination -> list of - # tuple(pending pdus, deferred, order) - self.pending_pdus_by_dest = pdus = {} - # destination -> list of tuple(edu, deferred) - self.pending_edus_by_dest = edus = {} - # Map of user_id -> UserPresenceState for all the pending presence # to be sent out by user_id. Entries here get processed and put in # pending_presence_by_dest self.pending_presence = {} - # Map of destination -> user_id -> UserPresenceState of pending presence - # to be sent to each destinations - self.pending_presence_by_dest = presence = {} - - # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered - # based on their key (e.g. typing events by room_id) - # Map of destination -> (edu_type, key) -> Edu - self.pending_edus_keyed_by_dest = edus_keyed = {} - LaterGauge( "synapse_federation_transaction_queue_pending_pdus", "", [], - lambda: sum(map(len, pdus.values())), + lambda: sum( + d.pending_pdu_count() for d in self._per_destination_queues.values() + ), ) LaterGauge( "synapse_federation_transaction_queue_pending_edus", "", [], - lambda: ( - sum(map(len, edus.values())) - + sum(map(len, presence.values())) - + sum(map(len, edus_keyed.values())) + lambda: sum( + d.pending_edu_count() for d in self._per_destination_queues.values() ), ) - # destination -> stream_id of last successfully sent to-device message. - # NB: may be a long or an int. 
- self.last_device_stream_id_by_dest = {} - - # destination -> stream_id of last successfully sent device list - # update. - self.last_device_list_stream_id_by_dest = {} - - # HACK to get unique tx id - self._next_txn_id = int(self.clock.time_msec()) - self._order = 1 self._is_processing = False @@ -154,6 +130,13 @@ class TransactionQueue(object): self._processing_pending_presence = False + def _get_per_destination_queue(self, destination): + queue = self._per_destination_queues.get(destination) + if not queue: + queue = PerDestinationQueue(self.hs, self._transaction_sender, destination) + self._per_destination_queues[destination] = queue + return queue + def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -284,11 +267,7 @@ class TransactionQueue(object): sent_pdus_destination_dist_count.inc() for destination in destinations: - self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, order) - ) - - self._attempt_new_transaction(destination) + self._get_per_destination_queue(destination).send_pdu(pdu, order) @defer.inlineCallbacks def send_read_receipt(self, receipt): @@ -387,14 +366,7 @@ class TransactionQueue(object): for destination in destinations: if destination == self.server_name: continue - - self.pending_presence_by_dest.setdefault( - destination, {} - ).update({ - state.user_id: state for state in states - }) - - self._attempt_new_transaction(destination) + self._get_per_destination_queue(destination).send_presence(states) def build_and_send_edu(self, destination, edu_type, content, key=None): """Construct an Edu object, and queue it for sending @@ -425,73 +397,136 @@ class TransactionQueue(object): edu (Edu): edu to send key (Any|None): clobbering key for this edu """ + queue = self._get_per_destination_queue(edu.destination) if key: - self.pending_edus_keyed_by_dest.setdefault( - edu.destination, {} - )[(edu.edu_type, key)] = edu + queue.send_keyed_edu(edu, 
key) else: - self.pending_edus_by_dest.setdefault(edu.destination, []).append(edu) - - self._attempt_new_transaction(edu.destination) + queue.send_edu(edu) def send_device_messages(self, destination): if destination == self.server_name: logger.info("Not sending device update to ourselves") return - self._attempt_new_transaction(destination) + self._get_per_destination_queue(destination).attempt_new_transaction() def get_current_token(self): return 0 - def _attempt_new_transaction(self, destination): + +class PerDestinationQueue(object): + """ + Manages the per-destination transmission queues. + """ + def __init__(self, hs, transaction_sender, destination): + self._server_name = hs.hostname + self._clock = hs.get_clock() + self._store = hs.get_datastore() + self._transaction_sender = transaction_sender + + self._destination = destination + self.transmission_loop_running = False + + # a list of tuples of (pending pdu, order) + self._pending_pdus = [] # type: list[tuple[EventBase, int]] + self._pending_edus = [] # type: list[Edu] + + # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered + # based on their key (e.g. typing events by room_id) + # Map of (edu_type, key) -> Edu + self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] + + # Map of user_id -> UserPresenceState of pending presence to be sent to this + # destination + self._pending_presence = {} # type: dict[str, UserPresenceState] + + # stream_id of last successfully sent to-device message. + # NB: may be a long or an int. + self._last_device_stream_id = 0 + + # stream_id of last successfully sent device list update. 
+ self._last_device_list_stream_id = 0 + + def pending_pdu_count(self): + return len(self._pending_pdus) + + def pending_edu_count(self): + return ( + len(self._pending_edus) + + len(self._pending_presence) + + len(self._pending_edus_keyed) + ) + + def send_pdu(self, pdu, order): + """Add a PDU to the queue, and start the transmission loop if necessary + + Args: + pdu (EventBase): pdu to send + order (int): + """ + self._pending_pdus.append((pdu, order)) + self.attempt_new_transaction() + + def send_presence(self, states): + """Add presence updates to the queue. Start the transmission loop if necessary. + + Args: + states (iterable[UserPresenceState]): presence to send + """ + self._pending_presence.update({ + state.user_id: state for state in states + }) + self.attempt_new_transaction() + + def send_keyed_edu(self, edu, key): + self._pending_edus_keyed[(edu.edu_type, key)] = edu + self.attempt_new_transaction() + + def send_edu(self, edu): + self._pending_edus.append(edu) + self.attempt_new_transaction() + + def attempt_new_transaction(self): """Try to start a new transaction to this destination If there is already a transaction in progress to this destination, returns immediately. Otherwise kicks off the process of sending a transaction in the background. - - Args: - destination (str): - - Returns: - None """ # list of (pending_pdu, deferred, order) - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. + if self.transmission_loop_running: + # XXX: this can get stuck on by a never-ending + # request at which point pending_pdus just keeps growing.
# we need application-layer timeouts of some flavour of these # requests logger.debug( "TX [%s] Transaction already in progress", - destination + self._destination ) return - logger.debug("TX [%s] Starting transaction loop", destination) + logger.debug("TX [%s] Starting transaction loop", self._destination) run_as_background_process( "federation_transaction_transmission_loop", self._transaction_transmission_loop, - destination, ) @defer.inlineCallbacks - def _transaction_transmission_loop(self, destination): + def _transaction_transmission_loop(self): pending_pdus = [] try: - self.pending_transactions[destination] = 1 + self.transmission_loop_running = True # This will throw if we wouldn't retry. We do this here so we fail # quickly, but we will later check this again in the http client, # hence why we throw the result away. - yield get_retry_limiter(destination, self.clock, self.store) + yield get_retry_limiter(self._destination, self._clock, self._store) pending_pdus = [] while True: device_message_edus, device_stream_id, dev_list_id = ( - yield self._get_new_device_messages(destination) + yield self._get_new_device_messages() ) # BEGIN CRITICAL SECTION @@ -501,39 +536,38 @@ class TransactionQueue(object): # where we decide if we actually have any pending messages) is # atomic - otherwise new PDUs or EDUs might arrive in the # meantime, but not get sent because we hold the - # pending_transactions flag. + # transmission_loop_running flag. 
- pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_pdus = self._pending_pdus # We can only include at most 50 PDUs per transactions - pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:] - if leftover_pdus: - self.pending_pdus_by_dest[destination] = leftover_pdus + pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:] - pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_edus = self._pending_edus # We can only include at most 100 EDUs per transactions - pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:] - if leftover_edus: - self.pending_edus_by_dest[destination] = leftover_edus - - pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_edus, self._pending_edus = pending_edus[:100], pending_edus[100:] pending_edus.extend( - self.pending_edus_keyed_by_dest.pop(destination, {}).values() + self._pending_edus_keyed.values() ) + self._pending_edus_keyed = {} + pending_edus.extend(device_message_edus) + + pending_presence = self._pending_presence + self._pending_presence = {} if pending_presence: pending_edus.append( Edu( - origin=self.server_name, - destination=destination, + origin=self._server_name, + destination=self._destination, edu_type="m.presence", content={ "push": [ format_user_presence_state( - presence, self.clock.time_msec() + presence, self._clock.time_msec() ) for presence in pending_presence.values() ] @@ -543,19 +577,17 @@ class TransactionQueue(object): if pending_pdus: logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) + self._destination, len(pending_pdus)) if not pending_pdus and not pending_edus: - logger.debug("TX [%s] Nothing to send", destination) - self.last_device_stream_id_by_dest[destination] = ( - device_stream_id - ) + logger.debug("TX [%s] Nothing to send", self._destination) + self._last_device_stream_id = device_stream_id return # END CRITICAL SECTION - success = yield 
self._send_new_transaction( - destination, pending_pdus, pending_edus, + success = yield self._transaction_sender.send_new_transaction( + self._destination, pending_pdus, pending_edus ) if success: sent_transactions_counter.inc() @@ -565,23 +597,25 @@ class TransactionQueue(object): # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages if device_message_edus: - yield self.store.delete_device_msgs_for_remote( - destination, device_stream_id + yield self._store.delete_device_msgs_for_remote( + self._destination, device_stream_id ) - logger.info("Marking as sent %r %r", destination, dev_list_id) - yield self.store.mark_as_sent_devices_by_remote( - destination, dev_list_id + logger.info( + "Marking as sent %r %r", self._destination, dev_list_id + ) + yield self._store.mark_as_sent_devices_by_remote( + self._destination, dev_list_id ) - self.last_device_stream_id_by_dest[destination] = device_stream_id - self.last_device_list_stream_id_by_dest[destination] = dev_list_id + self._last_device_stream_id = device_stream_id + self._last_device_list_stream_id = dev_list_id else: break except NotRetryingDestination as e: logger.debug( "TX [%s] not ready for retry yet (next retry at %s) - " "dropping transaction for now", - destination, + self._destination, datetime.datetime.fromtimestamp( (e.retry_last_ts + e.retry_interval) / 1000.0 ), @@ -591,51 +625,51 @@ class TransactionQueue(object): except HttpResponseException as e: logger.warning( "TX [%s] Received %d response to transaction: %s", - destination, e.code, e, + self._destination, e.code, e, ) except RequestSendFailed as e: - logger.warning("TX [%s] Failed to send transaction: %s", destination, e) + logger.warning("TX [%s] Failed to send transaction: %s", self._destination, e) for p, _ in pending_pdus: logger.info("Failed to send event %s to %s", p.event_id, - destination) + self._destination) except Exception: logger.exception( "TX [%s] Failed to send 
transaction", - destination, + self._destination, ) for p, _ in pending_pdus: logger.info("Failed to send event %s to %s", p.event_id, - destination) + self._destination) finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) + # We want to be *very* sure we clear this after we stop processing + self.transmission_loop_running = False @defer.inlineCallbacks - def _get_new_device_messages(self, destination): - last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) - to_device_stream_id = self.store.get_to_device_stream_token() - contents, stream_id = yield self.store.get_new_device_msgs_for_remote( - destination, last_device_stream_id, to_device_stream_id + def _get_new_device_messages(self): + last_device_stream_id = self._last_device_stream_id + to_device_stream_id = self._store.get_to_device_stream_token() + contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + self._destination, last_device_stream_id, to_device_stream_id ) edus = [ Edu( - origin=self.server_name, - destination=destination, + origin=self._server_name, + destination=self._destination, edu_type="m.direct_to_device", content=content, ) for content in contents ] - last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0) - now_stream_id, results = yield self.store.get_devices_by_remote( - destination, last_device_list + last_device_list = self._last_device_list_stream_id + now_stream_id, results = yield self._store.get_devices_by_remote( + self._destination, last_device_list ) edus.extend( Edu( - origin=self.server_name, - destination=destination, + origin=self._server_name, + destination=self._destination, edu_type="m.device_list_update", content=content, ) @@ -643,9 +677,25 @@ class TransactionQueue(object): ) defer.returnValue((edus, stream_id, now_stream_id)) + +class TransactionSender(object): + """Helper class which handles building and sending transactions 
+ + shared between PerDestinationQueue objects + """ + def __init__(self, hs): + self._server_name = hs.hostname + self._clock = hs.get_clock() + self._store = hs.get_datastore() + self._transaction_actions = TransactionActions(self._store) + self._transport_layer = hs.get_federation_transport_client() + + # HACK to get unique tx id + self._next_txn_id = int(self._clock.time_msec()) + @measure_func("_send_new_transaction") @defer.inlineCallbacks - def _send_new_transaction(self, destination, pending_pdus, pending_edus): + def send_new_transaction(self, destination, pending_pdus, pending_edus): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -669,9 +719,9 @@ class TransactionQueue(object): logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( - origin_server_ts=int(self.clock.time_msec()), + origin_server_ts=int(self._clock.time_msec()), transaction_id=txn_id, - origin=self.server_name, + origin=self._server_name, destination=destination, pdus=pdus, edus=edus, @@ -679,7 +729,7 @@ class TransactionQueue(object): self._next_txn_id += 1 - yield self.transaction_actions.prepare_to_send(transaction) + yield self._transaction_actions.prepare_to_send(transaction) logger.debug("TX [%s] Persisted transaction", destination) logger.info( @@ -697,7 +747,7 @@ class TransactionQueue(object): # keys work def json_data_cb(): data = transaction.get_dict() - now = int(self.clock.time_msec()) + now = int(self._clock.time_msec()) if "pdus" in data: for p in data["pdus"]: if "age_ts" in p: @@ -707,7 +757,7 @@ class TransactionQueue(object): return data try: - response = yield self.transport_layer.send_transaction( + response = yield self._transport_layer.send_transaction( transaction, json_data_cb ) code = 200 @@ -727,7 +777,7 @@ class TransactionQueue(object): destination, txn_id, code ) - yield self.transaction_actions.delivered( + yield self._transaction_actions.delivered( transaction, code, response )