Factor per-destination stuff out of TransactionQueue

This is easier than having to have a million fields keyed on destination.
This commit is contained in:
Richard van der Hoff 2019-03-12 21:57:47 +00:00
parent 6accbd25bc
commit 5d89a526f1

View file

@ -27,6 +27,7 @@ from synapse.api.errors import (
HttpResponseException, HttpResponseException,
RequestSendFailed, RequestSendFailed,
) )
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
from synapse.metrics import ( from synapse.metrics import (
LaterGauge, LaterGauge,
@ -36,6 +37,7 @@ from synapse.metrics import (
sent_transactions_counter, sent_transactions_counter,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import UserPresenceState
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@ -80,73 +82,47 @@ class TransactionQueue(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transaction_actions = TransactionActions(self.store)
self.transport_layer = hs.get_federation_transport_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
# Is a mapping from destinations -> deferreds. Used to keep track self._transaction_sender = TransactionSender(hs)
# of which destinations have transactions in flight and when they are
# done # map from destination to PerDestinationQueue
self.pending_transactions = {} self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
LaterGauge( LaterGauge(
"synapse_federation_transaction_queue_pending_destinations", "synapse_federation_transaction_queue_pending_destinations",
"", "",
[], [],
lambda: len(self.pending_transactions), lambda: sum(
1 for d in self._per_destination_queues.values()
if d.transmission_loop_running
),
) )
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = pdus = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = edus = {}
# Map of user_id -> UserPresenceState for all the pending presence # Map of user_id -> UserPresenceState for all the pending presence
# to be sent out by user_id. Entries here get processed and put in # to be sent out by user_id. Entries here get processed and put in
# pending_presence_by_dest # pending_presence_by_dest
self.pending_presence = {} self.pending_presence = {}
# Map of destination -> user_id -> UserPresenceState of pending presence
# to be sent to each destinations
self.pending_presence_by_dest = presence = {}
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of destination -> (edu_type, key) -> Edu
self.pending_edus_keyed_by_dest = edus_keyed = {}
LaterGauge( LaterGauge(
"synapse_federation_transaction_queue_pending_pdus", "synapse_federation_transaction_queue_pending_pdus",
"", "",
[], [],
lambda: sum(map(len, pdus.values())), lambda: sum(
d.pending_pdu_count() for d in self._per_destination_queues.values()
),
) )
LaterGauge( LaterGauge(
"synapse_federation_transaction_queue_pending_edus", "synapse_federation_transaction_queue_pending_edus",
"", "",
[], [],
lambda: ( lambda: sum(
sum(map(len, edus.values())) d.pending_edu_count() for d in self._per_destination_queues.values()
+ sum(map(len, presence.values()))
+ sum(map(len, edus_keyed.values()))
), ),
) )
# destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self.last_device_stream_id_by_dest = {}
# destination -> stream_id of last successfully sent device list
# update.
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
self._order = 1 self._order = 1
self._is_processing = False self._is_processing = False
@ -154,6 +130,13 @@ class TransactionQueue(object):
self._processing_pending_presence = False self._processing_pending_presence = False
def _get_per_destination_queue(self, destination):
queue = self._per_destination_queues.get(destination)
if not queue:
queue = PerDestinationQueue(self.hs, self._transaction_sender, destination)
self._per_destination_queues[destination] = queue
return queue
def notify_new_events(self, current_id): def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to """This gets called when we have some new events we might want to
send out to other servers. send out to other servers.
@ -284,11 +267,7 @@ class TransactionQueue(object):
sent_pdus_destination_dist_count.inc() sent_pdus_destination_dist_count.inc()
for destination in destinations: for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append( self._get_per_destination_queue(destination).send_pdu(pdu, order)
(pdu, order)
)
self._attempt_new_transaction(destination)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_read_receipt(self, receipt): def send_read_receipt(self, receipt):
@ -387,14 +366,7 @@ class TransactionQueue(object):
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
self._get_per_destination_queue(destination).send_presence(states)
self.pending_presence_by_dest.setdefault(
destination, {}
).update({
state.user_id: state for state in states
})
self._attempt_new_transaction(destination)
def build_and_send_edu(self, destination, edu_type, content, key=None): def build_and_send_edu(self, destination, edu_type, content, key=None):
"""Construct an Edu object, and queue it for sending """Construct an Edu object, and queue it for sending
@ -425,73 +397,136 @@ class TransactionQueue(object):
edu (Edu): edu to send edu (Edu): edu to send
key (Any|None): clobbering key for this edu key (Any|None): clobbering key for this edu
""" """
queue = self._get_per_destination_queue(edu.destination)
if key: if key:
self.pending_edus_keyed_by_dest.setdefault( queue.send_keyed_edu(edu, key)
edu.destination, {}
)[(edu.edu_type, key)] = edu
else: else:
self.pending_edus_by_dest.setdefault(edu.destination, []).append(edu) queue.send_edu(edu)
self._attempt_new_transaction(edu.destination)
def send_device_messages(self, destination): def send_device_messages(self, destination):
if destination == self.server_name: if destination == self.server_name:
logger.info("Not sending device update to ourselves") logger.info("Not sending device update to ourselves")
return return
self._attempt_new_transaction(destination) self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self): def get_current_token(self):
return 0 return 0
def _attempt_new_transaction(self, destination):
class PerDestinationQueue(object):
"""
Manages the per-destination transmission queues.
"""
def __init__(self, hs, transaction_sender, destination):
self._server_name = hs.hostname
self._clock = hs.get_clock()
self._store = hs.get_datastore()
self._transaction_sender = transaction_sender
self._destination = destination
self.transmission_loop_running = False
# a list of tuples of (pending pdu, order)
self._pending_pdus = [] # type: list[tuple[EventBase, int]]
self._pending_edus = [] # type: list[Edu]
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
self._pending_presence = {} # type: dict[str, UserPresenceState]
# stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self._last_device_stream_id = 0
# stream_id of last successfully sent device list update.
self._last_device_list_stream_id = 0
def pending_pdu_count(self):
return len(self._pending_pdus)
def pending_edu_count(self):
return (
len(self._pending_edus)
+ len(self._pending_presence)
+ len(self._pending_edus_keyed)
)
def send_pdu(self, pdu, order):
"""Add a PDU to the queue, and start the transmission loop if neccessary
Args:
pdu (EventBase): pdu to send
order (int):
"""
self._pending_pdus.append((pdu, order))
self.attempt_new_transaction()
def send_presence(self, states):
"""Add presence updates to the queue. Start the transmission loop if neccessary.
Args:
states (iterable[UserPresenceState]): presence to send
"""
self._pending_presence.update({
state.user_id: state for state in states
})
self.attempt_new_transaction()
def send_keyed_edu(self, edu, key):
self._pending_edus_keyed[(edu.edu_type, key)] = edu
self.attempt_new_transaction()
def send_edu(self, edu):
self._pending_edus.append(edu)
self.attempt_new_transaction()
def attempt_new_transaction(self):
"""Try to start a new transaction to this destination """Try to start a new transaction to this destination
If there is already a transaction in progress to this destination, If there is already a transaction in progress to this destination,
returns immediately. Otherwise kicks off the process of sending a returns immediately. Otherwise kicks off the process of sending a
transaction in the background. transaction in the background.
Args:
destination (str):
Returns:
None
""" """
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
if destination in self.pending_transactions: if self.transmission_loop_running:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: this can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing. # request at which point pending_pdus just keeps growing.
# we need application-layer timeouts of some flavour of these # we need application-layer timeouts of some flavour of these
# requests # requests
logger.debug( logger.debug(
"TX [%s] Transaction already in progress", "TX [%s] Transaction already in progress",
destination self._destination
) )
return return
logger.debug("TX [%s] Starting transaction loop", destination) logger.debug("TX [%s] Starting transaction loop", self._destination)
run_as_background_process( run_as_background_process(
"federation_transaction_transmission_loop", "federation_transaction_transmission_loop",
self._transaction_transmission_loop, self._transaction_transmission_loop,
destination,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _transaction_transmission_loop(self, destination): def _transaction_transmission_loop(self):
pending_pdus = [] pending_pdus = []
try: try:
self.pending_transactions[destination] = 1 self.transmission_loop_running = True
# This will throw if we wouldn't retry. We do this here so we fail # This will throw if we wouldn't retry. We do this here so we fail
# quickly, but we will later check this again in the http client, # quickly, but we will later check this again in the http client,
# hence why we throw the result away. # hence why we throw the result away.
yield get_retry_limiter(destination, self.clock, self.store) yield get_retry_limiter(self._destination, self._clock, self._store)
pending_pdus = [] pending_pdus = []
while True: while True:
device_message_edus, device_stream_id, dev_list_id = ( device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination) yield self._get_new_device_messages()
) )
# BEGIN CRITICAL SECTION # BEGIN CRITICAL SECTION
@ -501,39 +536,38 @@ class TransactionQueue(object):
# where we decide if we actually have any pending messages) is # where we decide if we actually have any pending messages) is
# atomic - otherwise new PDUs or EDUs might arrive in the # atomic - otherwise new PDUs or EDUs might arrive in the
# meantime, but not get sent because we hold the # meantime, but not get sent because we hold the
# pending_transactions flag. # transmission_loop_running flag.
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self._pending_pdus
# We can only include at most 50 PDUs per transactions # We can only include at most 50 PDUs per transactions
pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:] pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
if leftover_pdus:
self.pending_pdus_by_dest[destination] = leftover_pdus
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self._pending_edus
# We can only include at most 100 EDUs per transactions # We can only include at most 100 EDUs per transactions
pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:] pending_edus, self._pending_edus = pending_edus[:100], pending_edus[100:]
if leftover_edus:
self.pending_edus_by_dest[destination] = leftover_edus
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_edus.extend( pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values() self._pending_edus_keyed.values()
) )
self._pending_edus_keyed = {}
pending_edus.extend(device_message_edus) pending_edus.extend(device_message_edus)
pending_presence = self._pending_presence
self._pending_presence = {}
if pending_presence: if pending_presence:
pending_edus.append( pending_edus.append(
Edu( Edu(
origin=self.server_name, origin=self._server_name,
destination=destination, destination=self._destination,
edu_type="m.presence", edu_type="m.presence",
content={ content={
"push": [ "push": [
format_user_presence_state( format_user_presence_state(
presence, self.clock.time_msec() presence, self._clock.time_msec()
) )
for presence in pending_presence.values() for presence in pending_presence.values()
] ]
@ -543,19 +577,17 @@ class TransactionQueue(object):
if pending_pdus: if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus)) self._destination, len(pending_pdus))
if not pending_pdus and not pending_edus: if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", destination) logger.debug("TX [%s] Nothing to send", self._destination)
self.last_device_stream_id_by_dest[destination] = ( self._last_device_stream_id = device_stream_id
device_stream_id
)
return return
# END CRITICAL SECTION # END CRITICAL SECTION
success = yield self._send_new_transaction( success = yield self._transaction_sender.send_new_transaction(
destination, pending_pdus, pending_edus, self._destination, pending_pdus, pending_edus
) )
if success: if success:
sent_transactions_counter.inc() sent_transactions_counter.inc()
@ -565,23 +597,25 @@ class TransactionQueue(object):
# Remove the acknowledged device messages from the database # Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages # Only bother if we actually sent some device messages
if device_message_edus: if device_message_edus:
yield self.store.delete_device_msgs_for_remote( yield self._store.delete_device_msgs_for_remote(
destination, device_stream_id self._destination, device_stream_id
) )
logger.info("Marking as sent %r %r", destination, dev_list_id) logger.info(
yield self.store.mark_as_sent_devices_by_remote( "Marking as sent %r %r", self._destination, dev_list_id
destination, dev_list_id )
yield self._store.mark_as_sent_devices_by_remote(
self._destination, dev_list_id
) )
self.last_device_stream_id_by_dest[destination] = device_stream_id self._last_device_stream_id = device_stream_id
self.last_device_list_stream_id_by_dest[destination] = dev_list_id self._last_device_list_stream_id = dev_list_id
else: else:
break break
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.debug( logger.debug(
"TX [%s] not ready for retry yet (next retry at %s) - " "TX [%s] not ready for retry yet (next retry at %s) - "
"dropping transaction for now", "dropping transaction for now",
destination, self._destination,
datetime.datetime.fromtimestamp( datetime.datetime.fromtimestamp(
(e.retry_last_ts + e.retry_interval) / 1000.0 (e.retry_last_ts + e.retry_interval) / 1000.0
), ),
@ -591,51 +625,51 @@ class TransactionQueue(object):
except HttpResponseException as e: except HttpResponseException as e:
logger.warning( logger.warning(
"TX [%s] Received %d response to transaction: %s", "TX [%s] Received %d response to transaction: %s",
destination, e.code, e, self._destination, e.code, e,
) )
except RequestSendFailed as e: except RequestSendFailed as e:
logger.warning("TX [%s] Failed to send transaction: %s", destination, e) logger.warning("TX [%s] Failed to send transaction: %s", self._destination, e)
for p, _ in pending_pdus: for p, _ in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id, logger.info("Failed to send event %s to %s", p.event_id,
destination) self._destination)
except Exception: except Exception:
logger.exception( logger.exception(
"TX [%s] Failed to send transaction", "TX [%s] Failed to send transaction",
destination, self._destination,
) )
for p, _ in pending_pdus: for p, _ in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id, logger.info("Failed to send event %s to %s", p.event_id,
destination) self._destination)
finally: finally:
# We want to be *very* sure we delete this after we stop processing # We want to be *very* sure we clear this after we stop processing
self.pending_transactions.pop(destination, None) self.transmission_loop_running = False
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_new_device_messages(self, destination): def _get_new_device_messages(self):
last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) last_device_stream_id = self._last_device_stream_id
to_device_stream_id = self.store.get_to_device_stream_token() to_device_stream_id = self._store.get_to_device_stream_token()
contents, stream_id = yield self.store.get_new_device_msgs_for_remote( contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
destination, last_device_stream_id, to_device_stream_id self._destination, last_device_stream_id, to_device_stream_id
) )
edus = [ edus = [
Edu( Edu(
origin=self.server_name, origin=self._server_name,
destination=destination, destination=self._destination,
edu_type="m.direct_to_device", edu_type="m.direct_to_device",
content=content, content=content,
) )
for content in contents for content in contents
] ]
last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0) last_device_list = self._last_device_list_stream_id
now_stream_id, results = yield self.store.get_devices_by_remote( now_stream_id, results = yield self._store.get_devices_by_remote(
destination, last_device_list self._destination, last_device_list
) )
edus.extend( edus.extend(
Edu( Edu(
origin=self.server_name, origin=self._server_name,
destination=destination, destination=self._destination,
edu_type="m.device_list_update", edu_type="m.device_list_update",
content=content, content=content,
) )
@ -643,9 +677,25 @@ class TransactionQueue(object):
) )
defer.returnValue((edus, stream_id, now_stream_id)) defer.returnValue((edus, stream_id, now_stream_id))
class TransactionSender(object):
"""Helper class which handles building and sending transactions
shared between PerDestinationQueue objects
"""
def __init__(self, hs):
self._server_name = hs.hostname
self._clock = hs.get_clock()
self._store = hs.get_datastore()
self._transaction_actions = TransactionActions(self._store)
self._transport_layer = hs.get_federation_transport_client()
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus): def send_new_transaction(self, destination, pending_pdus, pending_edus):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
@ -669,9 +719,9 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new( transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()), origin_server_ts=int(self._clock.time_msec()),
transaction_id=txn_id, transaction_id=txn_id,
origin=self.server_name, origin=self._server_name,
destination=destination, destination=destination,
pdus=pdus, pdus=pdus,
edus=edus, edus=edus,
@ -679,7 +729,7 @@ class TransactionQueue(object):
self._next_txn_id += 1 self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction) yield self._transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination) logger.debug("TX [%s] Persisted transaction", destination)
logger.info( logger.info(
@ -697,7 +747,7 @@ class TransactionQueue(object):
# keys work # keys work
def json_data_cb(): def json_data_cb():
data = transaction.get_dict() data = transaction.get_dict()
now = int(self.clock.time_msec()) now = int(self._clock.time_msec())
if "pdus" in data: if "pdus" in data:
for p in data["pdus"]: for p in data["pdus"]:
if "age_ts" in p: if "age_ts" in p:
@ -707,7 +757,7 @@ class TransactionQueue(object):
return data return data
try: try:
response = yield self.transport_layer.send_transaction( response = yield self._transport_layer.send_transaction(
transaction, json_data_cb transaction, json_data_cb
) )
code = 200 code = 200
@ -727,7 +777,7 @@ class TransactionQueue(object):
destination, txn_id, code destination, txn_id, code
) )
yield self.transaction_actions.delivered( yield self._transaction_actions.delivered(
transaction, code, response transaction, code, response
) )