Experimental Federation Speedup (#9702)

This speeds up federation by "squeezing" the pair of database calls (to `destinations` and `destination_rooms`) that previously happened for every single event into one pair of calls for an entire batch of events (100 max).

Signed-off-by: Jonathan de Jong <jonathan@automatia.nl>
This commit is contained in:
Jonathan de Jong 2021-04-14 18:19:02 +02:00 committed by GitHub
parent 00a6db9676
commit 05e8c70c05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 127 additions and 95 deletions

1
changelog.d/9702.misc Normal file
View file

@ -0,0 +1 @@
Speed up federation transmission by using fewer database calls. Contributed by @ShadowJonathan.

View file

@ -224,7 +224,8 @@ class HomeServer(ReplicationHandler):
destinations = yield self.get_servers_for_context(room_name) destinations = yield self.get_servers_for_context(room_name)
try: try:
yield self.replication_layer.send_pdu( yield self.replication_layer.send_pdus(
[
Pdu.create_new( Pdu.create_new(
context=room_name, context=room_name,
pdu_type="sy.room.message", pdu_type="sy.room.message",
@ -232,6 +233,7 @@ class HomeServer(ReplicationHandler):
origin=self.server_name, origin=self.server_name,
destinations=destinations, destinations=destinations,
) )
]
) )
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
@ -253,7 +255,7 @@ class HomeServer(ReplicationHandler):
origin=self.server_name, origin=self.server_name,
destinations=destinations, destinations=destinations,
) )
yield self.replication_layer.send_pdu(pdu) yield self.replication_layer.send_pdus([pdu])
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
@ -265,7 +267,8 @@ class HomeServer(ReplicationHandler):
destinations = yield self.get_servers_for_context(room_name) destinations = yield self.get_servers_for_context(room_name)
try: try:
yield self.replication_layer.send_pdu( yield self.replication_layer.send_pdus(
[
Pdu.create_new( Pdu.create_new(
context=room_name, context=room_name,
is_state=True, is_state=True,
@ -275,6 +278,7 @@ class HomeServer(ReplicationHandler):
origin=self.server_name, origin=self.server_name,
destinations=destinations, destinations=destinations,
) )
]
) )
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)

View file

@ -18,8 +18,6 @@ from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set,
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics import synapse.metrics
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.events import EventBase from synapse.events import EventBase
@ -27,11 +25,7 @@ from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu from synapse.federation.units import Edu
from synapse.handlers.presence import get_interested_remotes from synapse.handlers.presence import get_interested_remotes
from synapse.logging.context import ( from synapse.logging.context import preserve_fn
make_deferred_yieldable,
preserve_fn,
run_in_background,
)
from synapse.metrics import ( from synapse.metrics import (
LaterGauge, LaterGauge,
event_processing_loop_counter, event_processing_loop_counter,
@ -39,7 +33,7 @@ from synapse.metrics import (
events_processed_counter, events_processed_counter,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict, ReadReceipt, RoomStreamToken from synapse.types import Collection, JsonDict, ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING: if TYPE_CHECKING:
@ -276,15 +270,27 @@ class FederationSender(AbstractFederationSender):
if not events and next_token >= self._last_poked_id: if not events and next_token >= self._last_poked_id:
break break
async def handle_event(event: EventBase) -> None: async def get_destinations_for_event(
event: EventBase,
) -> Collection[str]:
"""Computes the destinations to which this event must be sent.
This returns an empty tuple when there are no destinations to send to,
or if this event is not from this homeserver and it is not sending
it on behalf of another server.
Will also filter out destinations which this sender is not responsible for,
if multiple federation senders exist.
"""
# Only send events for this server. # Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.sender) is_mine = self.is_mine_id(event.sender)
if not is_mine and send_on_behalf_of is None: if not is_mine and send_on_behalf_of is None:
return return ()
if not event.internal_metadata.should_proactively_send(): if not event.internal_metadata.should_proactively_send():
return return ()
destinations = None # type: Optional[Set[str]] destinations = None # type: Optional[Set[str]]
if not event.prev_event_ids(): if not event.prev_event_ids():
@ -319,7 +325,7 @@ class FederationSender(AbstractFederationSender):
"Failed to calculate hosts in room for event: %s", "Failed to calculate hosts in room for event: %s",
event.event_id, event.event_id,
) )
return return ()
destinations = { destinations = {
d d
@ -329,17 +335,15 @@ class FederationSender(AbstractFederationSender):
) )
} }
destinations.discard(self.server_name)
if send_on_behalf_of is not None: if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server # If we are sending the event on behalf of another server
# then it already has the event and there is no reason to # then it already has the event and there is no reason to
# send the event to it. # send the event to it.
destinations.discard(send_on_behalf_of) destinations.discard(send_on_behalf_of)
logger.debug("Sending %s to %r", event, destinations)
if destinations: if destinations:
await self._send_pdu(event, destinations)
now = self.clock.time_msec() now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id) ts = await self.store.get_received_ts(event.event_id)
@ -347,24 +351,29 @@ class FederationSender(AbstractFederationSender):
"federation_sender" "federation_sender"
).observe((now - ts) / 1000) ).observe((now - ts) / 1000)
async def handle_room_events(events: Iterable[EventBase]) -> None: return destinations
with Measure(self.clock, "handle_room_events"): return ()
for event in events:
await handle_event(event)
events_by_room = {} # type: Dict[str, List[EventBase]] async def get_federatable_events_and_destinations(
for event in events: events: Iterable[EventBase],
events_by_room.setdefault(event.room_id, []).append(event) ) -> List[Tuple[EventBase, Collection[str]]]:
with Measure(self.clock, "get_destinations_for_events"):
# Fetch federation destinations per event,
# skip if get_destinations_for_event returns an empty collection,
# return list of event->destinations pairs.
return [
(event, dests)
for (event, dests) in [
(event, await get_destinations_for_event(event))
for event in events
]
if dests
]
await make_deferred_yieldable( events_and_dests = await get_federatable_events_and_destinations(events)
defer.gatherResults(
[ # Send corresponding events to each destination queue
run_in_background(handle_room_events, evs) await self._distribute_events(events_and_dests)
for evs in events_by_room.values()
],
consumeErrors=True,
)
)
await self.store.update_federation_out_pos("events", next_token) await self.store.update_federation_out_pos("events", next_token)
@ -382,7 +391,7 @@ class FederationSender(AbstractFederationSender):
events_processed_counter.inc(len(events)) events_processed_counter.inc(len(events))
event_processing_loop_room_count.labels("federation_sender").inc( event_processing_loop_room_count.labels("federation_sender").inc(
len(events_by_room) len({event.room_id for event in events})
) )
event_processing_loop_counter.labels("federation_sender").inc() event_processing_loop_counter.labels("federation_sender").inc()
@ -394,34 +403,53 @@ class FederationSender(AbstractFederationSender):
finally: finally:
self._is_processing = False self._is_processing = False
async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: async def _distribute_events(
# We loop through all destinations to see whether we already have self,
# a transaction in progress. If we do, stick it in the pending_pdus events_and_dests: Iterable[Tuple[EventBase, Collection[str]]],
# table and we'll get back to it later. ) -> None:
"""Distribute events to the respective per_destination queues.
destinations = set(destinations) Also persists last-seen per-room stream_ordering to 'destination_rooms'.
destinations.discard(self.server_name)
logger.debug("Sending to: %s", str(destinations))
if not destinations: Args:
return events_and_dests: A list of tuples, which are (event: EventBase, destinations: Collection[str]).
Every event is paired with its intended destinations (in federation).
"""
# Tuples of room_id + destination to their max-seen stream_ordering
room_with_dest_stream_ordering = {} # type: Dict[Tuple[str, str], int]
# List of events to send to each destination
events_by_dest = {} # type: Dict[str, List[EventBase]]
# For each event-destinations pair...
for event, destinations in events_and_dests:
# (we got this from the database, it's filled)
assert event.internal_metadata.stream_ordering
sent_pdus_destination_dist_total.inc(len(destinations)) sent_pdus_destination_dist_total.inc(len(destinations))
sent_pdus_destination_dist_count.inc() sent_pdus_destination_dist_count.inc()
assert pdu.internal_metadata.stream_ordering # ...iterate over those destinations..
for destination in destinations:
# track the fact that we have a PDU for these destinations, # ...update their stream-ordering...
# to allow us to perform catch-up later on if the remote is unreachable room_with_dest_stream_ordering[(event.room_id, destination)] = max(
# for a while. event.internal_metadata.stream_ordering,
await self.store.store_destination_rooms_entries( room_with_dest_stream_ordering.get((event.room_id, destination), 0),
destinations,
pdu.room_id,
pdu.internal_metadata.stream_ordering,
) )
for destination in destinations: # ...and add the event to each destination queue.
self._get_per_destination_queue(destination).send_pdu(pdu) events_by_dest.setdefault(destination, []).append(event)
# Bulk-store destination_rooms stream_ids
await self.store.bulk_store_destination_rooms_entries(
room_with_dest_stream_ordering
)
for destination, pdus in events_by_dest.items():
logger.debug("Sending %d pdus to %s", len(pdus), destination)
self._get_per_destination_queue(destination).send_pdus(pdus)
async def send_read_receipt(self, receipt: ReadReceipt) -> None: async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""Send a RR to any other servers in the room """Send a RR to any other servers in the room

View file

@ -154,19 +154,22 @@ class PerDestinationQueue:
+ len(self._pending_edus_keyed) + len(self._pending_edus_keyed)
) )
def send_pdu(self, pdu: EventBase) -> None: def send_pdus(self, pdus: Iterable[EventBase]) -> None:
"""Add a PDU to the queue, and start the transmission loop if necessary """Add PDUs to the queue, and start the transmission loop if necessary
Args: Args:
pdu: pdu to send pdus: pdus to send
""" """
if not self._catching_up or self._last_successful_stream_ordering is None: if not self._catching_up or self._last_successful_stream_ordering is None:
# only enqueue the PDU if we are not catching up (False) or do not # only enqueue the PDU if we are not catching up (False) or do not
# yet know if we have anything to catch up (None) # yet know if we have anything to catch up (None)
self._pending_pdus.append(pdu) self._pending_pdus.extend(pdus)
else: else:
assert pdu.internal_metadata.stream_ordering self._catchup_last_skipped = max(
self._catchup_last_skipped = pdu.internal_metadata.stream_ordering pdu.internal_metadata.stream_ordering
for pdu in pdus
if pdu.internal_metadata.stream_ordering is not None
)
self.attempt_new_transaction() self.attempt_new_transaction()

View file

@ -14,7 +14,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Iterable, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -295,37 +295,33 @@ class TransactionStore(TransactionWorkerStore):
}, },
) )
async def store_destination_rooms_entries( async def bulk_store_destination_rooms_entries(
self, self, room_and_destination_to_ordering: Dict[Tuple[str, str], int]
destinations: Iterable[str], ):
room_id: str,
stream_ordering: int,
) -> None:
""" """
Updates or creates `destination_rooms` entries in batch for a single event. Updates or creates `destination_rooms` entries for a number of events.
Args: Args:
destinations: list of destinations room_and_destination_to_ordering: A mapping of (room, destination) -> stream_id
room_id: the room_id of the event
stream_ordering: the stream_ordering of the event
""" """
await self.db_pool.simple_upsert_many( await self.db_pool.simple_upsert_many(
table="destinations", table="destinations",
key_names=("destination",), key_names=("destination",),
key_values=[(d,) for d in destinations], key_values={(d,) for _, d in room_and_destination_to_ordering.keys()},
value_names=[], value_names=[],
value_values=[], value_values=[],
desc="store_destination_rooms_entries_dests", desc="store_destination_rooms_entries_dests",
) )
rows = [(destination, room_id) for destination in destinations]
await self.db_pool.simple_upsert_many( await self.db_pool.simple_upsert_many(
table="destination_rooms", table="destination_rooms",
key_names=("destination", "room_id"), key_names=("room_id", "destination"),
key_values=rows, key_values=list(room_and_destination_to_ordering.keys()),
value_names=["stream_ordering"], value_names=["stream_ordering"],
value_values=[(stream_ordering,)] * len(rows), value_values=[
(stream_id,) for stream_id in room_and_destination_to_ordering.values()
],
desc="store_destination_rooms_entries_rooms", desc="store_destination_rooms_entries_rooms",
) )