0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-15 00:58:22 +02:00

Add typing to synapse.federation.sender (#6871)

This commit is contained in:
Erik Johnston 2020-02-07 13:56:38 +00:00 committed by GitHub
parent de2d267375
commit b08b0a22d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 138 additions and 107 deletions

1
changelog.d/6871.misc Normal file
View file

@ -0,0 +1 @@
Add typing to `synapse.federation.sender` and port to async/await.

View file

@ -294,7 +294,12 @@ class FederationServer(FederationBase):
async def _process_edu(edu_dict):
received_edus_counter.inc()
edu = Edu(**edu_dict)
edu = Edu(
origin=origin,
destination=self.server_name,
edu_type=edu_dict["edu_type"],
content=edu_dict["content"],
)
await self.registry.on_edu(edu.edu_type, origin, edu.content)
await concurrently_execute(

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set
from six import itervalues
@ -23,6 +24,7 @@ from twisted.internet import defer
import synapse
import synapse.metrics
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu
@ -39,6 +41,8 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
@ -68,7 +72,7 @@ class FederationSender(object):
self._transaction_manager = TransactionManager(hs)
# map from destination to PerDestinationQueue
self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue]
LaterGauge(
"synapse_federation_transaction_queue_pending_destinations",
@ -84,7 +88,7 @@ class FederationSender(object):
# Map of user_id -> UserPresenceState for all the pending presence
# to be sent out by user_id. Entries here get processed and put in
# pending_presence_by_dest
self.pending_presence = {}
self.pending_presence = {} # type: Dict[str, UserPresenceState]
LaterGauge(
"synapse_federation_transaction_queue_pending_pdus",
@ -116,20 +120,17 @@ class FederationSender(object):
# and that there is a pending call to _flush_rrs_for_room in the system.
self._queues_awaiting_rr_flush_by_room = (
{}
) # type: dict[str, set[PerDestinationQueue]]
) # type: Dict[str, Set[PerDestinationQueue]]
self._rr_txn_interval_per_room_ms = (
1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second
1000.0 / hs.config.federation_rr_transactions_per_room_per_second
)
def _get_per_destination_queue(self, destination):
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
Args:
destination (str): server_name of remote server
Returns:
PerDestinationQueue
destination: server_name of remote server
"""
queue = self._per_destination_queues.get(destination)
if not queue:
@ -137,7 +138,7 @@ class FederationSender(object):
self._per_destination_queues[destination] = queue
return queue
def notify_new_events(self, current_id):
def notify_new_events(self, current_id: int) -> None:
"""This gets called when we have some new events we might want to
send out to other servers.
"""
@ -151,13 +152,12 @@ class FederationSender(object):
"process_event_queue_for_federation", self._process_event_queue_loop
)
@defer.inlineCallbacks
def _process_event_queue_loop(self):
async def _process_event_queue_loop(self) -> None:
try:
self._is_processing = True
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
last_token = await self.store.get_federation_out_pos("events")
next_token, events = await self.store.get_all_new_events_stream(
last_token, self._last_poked_id, limit=100
)
@ -166,8 +166,7 @@ class FederationSender(object):
if not events and next_token >= self._last_poked_id:
break
@defer.inlineCallbacks
def handle_event(event):
async def handle_event(event: EventBase) -> None:
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.sender)
@ -184,7 +183,7 @@ class FederationSender(object):
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = yield self.state.get_hosts_in_room_at_events(
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
@ -206,17 +205,16 @@ class FederationSender(object):
self._send_pdu(event, destinations)
@defer.inlineCallbacks
def handle_room_events(events):
async def handle_room_events(events: Iterable[EventBase]) -> None:
with Measure(self.clock, "handle_room_events"):
for event in events:
yield handle_event(event)
await handle_event(event)
events_by_room = {}
events_by_room = {} # type: Dict[str, List[EventBase]]
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
@ -226,11 +224,11 @@ class FederationSender(object):
)
)
yield self.store.update_federation_out_pos("events", next_token)
await self.store.update_federation_out_pos("events", next_token)
if events:
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id)
ts = await self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_lag.labels(
"federation_sender"
@ -254,7 +252,7 @@ class FederationSender(object):
finally:
self._is_processing = False
def _send_pdu(self, pdu, destinations):
def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
@ -276,11 +274,11 @@ class FederationSender(object):
self._get_per_destination_queue(destination).send_pdu(pdu, order)
@defer.inlineCallbacks
def send_read_receipt(self, receipt):
def send_read_receipt(self, receipt: ReadReceipt):
"""Send a RR to any other servers in the room
Args:
receipt (synapse.types.ReadReceipt): receipt to be sent
receipt: receipt to be sent
"""
# Some background on the rate-limiting going on here.
@ -343,7 +341,7 @@ class FederationSender(object):
else:
queue.flush_read_receipts_for_room(room_id)
def _schedule_rr_flush_for_room(self, room_id, n_domains):
def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None:
# that is going to cause approximately len(domains) transactions, so now back
# off for that multiplied by RR_TXN_INTERVAL_PER_ROOM
backoff_ms = self._rr_txn_interval_per_room_ms * n_domains
@ -352,7 +350,7 @@ class FederationSender(object):
self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id)
self._queues_awaiting_rr_flush_by_room[room_id] = set()
def _flush_rrs_for_room(self, room_id):
def _flush_rrs_for_room(self, room_id: str) -> None:
queues = self._queues_awaiting_rr_flush_by_room.pop(room_id)
logger.debug("Flushing RRs in %s to %s", room_id, queues)
@ -368,14 +366,11 @@ class FederationSender(object):
@preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
def send_presence(self, states):
def send_presence(self, states: List[UserPresenceState]):
"""Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and
triggers a background task to process them and send out the transactions.
Args:
states (list(UserPresenceState))
"""
if not self.hs.config.use_presence:
# No-op if presence is disabled.
@ -412,11 +407,10 @@ class FederationSender(object):
finally:
self._processing_pending_presence = False
def send_presence_to_destinations(self, states, destinations):
def send_presence_to_destinations(
self, states: List[UserPresenceState], destinations: List[str]
) -> None:
"""Send the given presence states to the given destinations.
Args:
states (list[UserPresenceState])
destinations (list[str])
"""
@ -431,12 +425,9 @@ class FederationSender(object):
@measure_func("txnqueue._process_presence")
@defer.inlineCallbacks
def _process_presence_inner(self, states):
def _process_presence_inner(self, states: List[UserPresenceState]):
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination
Args:
states (list(UserPresenceState))
"""
hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
@ -446,14 +437,20 @@ class FederationSender(object):
continue
self._get_per_destination_queue(destination).send_presence(states)
def build_and_send_edu(self, destination, edu_type, content, key=None):
def build_and_send_edu(
self,
destination: str,
edu_type: str,
content: dict,
key: Optional[Hashable] = None,
):
"""Construct an Edu object, and queue it for sending
Args:
destination (str): name of server to send to
edu_type (str): type of EDU to send
content (dict): content of EDU
key (Any|None): clobbering key for this edu
destination: name of server to send to
edu_type: type of EDU to send
content: content of EDU
key: clobbering key for this edu
"""
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
@ -468,12 +465,12 @@ class FederationSender(object):
self.send_edu(edu, key)
def send_edu(self, edu, key):
def send_edu(self, edu: Edu, key: Optional[Hashable]):
"""Queue an EDU for sending
Args:
edu (Edu): edu to send
key (Any|None): clobbering key for this edu
edu: edu to send
key: clobbering key for this edu
"""
queue = self._get_per_destination_queue(edu.destination)
if key:
@ -481,7 +478,7 @@ class FederationSender(object):
else:
queue.send_edu(edu)
def send_device_messages(self, destination):
def send_device_messages(self, destination: str):
if destination == self.server_name:
logger.warning("Not sending device update to ourselves")
return
@ -501,5 +498,5 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self):
def get_current_token(self) -> int:
return 0

View file

@ -15,11 +15,11 @@
# limitations under the License.
import datetime
import logging
from typing import Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
from twisted.internet import defer
import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@ -31,7 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import StateMap
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver.
@ -56,13 +56,18 @@ class PerDestinationQueue(object):
Manages the per-destination transmission queues.
Args:
hs (synapse.HomeServer):
transaction_sender (TransactionManager):
destination (str): the server_name of the destination that we are managing
hs
transaction_sender
destination: the server_name of the destination that we are managing
transmission for.
"""
def __init__(self, hs, transaction_manager, destination):
def __init__(
self,
hs: "synapse.server.HomeServer",
transaction_manager: "synapse.federation.sender.TransactionManager",
destination: str,
):
self._server_name = hs.hostname
self._clock = hs.get_clock()
self._store = hs.get_datastore()
@ -72,20 +77,20 @@ class PerDestinationQueue(object):
self.transmission_loop_running = False
# a list of tuples of (pending pdu, order)
self._pending_pdus = [] # type: list[tuple[EventBase, int]]
self._pending_edus = [] # type: list[Edu]
self._pending_pdus = [] # type: List[Tuple[EventBase, int]]
self._pending_edus = [] # type: List[Edu]
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
self._pending_edus_keyed = {} # type: StateMap[Edu]
self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
self._pending_presence = {} # type: dict[str, UserPresenceState]
self._pending_presence = {} # type: Dict[str, UserPresenceState]
# room_id -> receipt_type -> user_id -> receipt_dict
self._pending_rrs = {}
self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]]
self._rrs_pending_flush = False
# stream_id of last successfully sent to-device message.
@ -95,50 +100,50 @@ class PerDestinationQueue(object):
# stream_id of last successfully sent device list update.
self._last_device_list_stream_id = 0
def __str__(self):
def __str__(self) -> str:
return "PerDestinationQueue[%s]" % self._destination
def pending_pdu_count(self):
def pending_pdu_count(self) -> int:
return len(self._pending_pdus)
def pending_edu_count(self):
def pending_edu_count(self) -> int:
return (
len(self._pending_edus)
+ len(self._pending_presence)
+ len(self._pending_edus_keyed)
)
def send_pdu(self, pdu, order):
def send_pdu(self, pdu: EventBase, order: int) -> None:
"""Add a PDU to the queue, and start the transmission loop if neccessary
Args:
pdu (EventBase): pdu to send
order (int):
pdu: pdu to send
order
"""
self._pending_pdus.append((pdu, order))
self.attempt_new_transaction()
def send_presence(self, states):
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
"""Add presence updates to the queue. Start the transmission loop if neccessary.
Args:
states (iterable[UserPresenceState]): presence to send
states: presence to send
"""
self._pending_presence.update({state.user_id: state for state in states})
self.attempt_new_transaction()
def queue_read_receipt(self, receipt):
def queue_read_receipt(self, receipt: ReadReceipt) -> None:
"""Add a RR to the list to be sent. Doesn't start the transmission loop yet
(see flush_read_receipts_for_room)
Args:
receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued
receipt: receipt to be queued
"""
self._pending_rrs.setdefault(receipt.room_id, {}).setdefault(
receipt.receipt_type, {}
)[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data}
def flush_read_receipts_for_room(self, room_id):
def flush_read_receipts_for_room(self, room_id: str) -> None:
# if we don't have any read-receipts for this room, it may be that we've already
# sent them out, so we don't need to flush.
if room_id not in self._pending_rrs:
@ -146,15 +151,15 @@ class PerDestinationQueue(object):
self._rrs_pending_flush = True
self.attempt_new_transaction()
def send_keyed_edu(self, edu, key):
def send_keyed_edu(self, edu: Edu, key: Hashable) -> None:
self._pending_edus_keyed[(edu.edu_type, key)] = edu
self.attempt_new_transaction()
def send_edu(self, edu):
def send_edu(self, edu) -> None:
self._pending_edus.append(edu)
self.attempt_new_transaction()
def attempt_new_transaction(self):
def attempt_new_transaction(self) -> None:
"""Try to start a new transaction to this destination
If there is already a transaction in progress to this destination,
@ -177,23 +182,22 @@ class PerDestinationQueue(object):
self._transaction_transmission_loop,
)
@defer.inlineCallbacks
def _transaction_transmission_loop(self):
pending_pdus = []
async def _transaction_transmission_loop(self) -> None:
pending_pdus = [] # type: List[Tuple[EventBase, int]]
try:
self.transmission_loop_running = True
# This will throw if we wouldn't retry. We do this here so we fail
# quickly, but we will later check this again in the http client,
# hence why we throw the result away.
yield get_retry_limiter(self._destination, self._clock, self._store)
await get_retry_limiter(self._destination, self._clock, self._store)
pending_pdus = []
while True:
# We have to keep 2 free slots for presence and rr_edus
limit = MAX_EDUS_PER_TRANSACTION - 2
device_update_edus, dev_list_id = yield self._get_device_update_edus(
device_update_edus, dev_list_id = await self._get_device_update_edus(
limit
)
@ -202,7 +206,7 @@ class PerDestinationQueue(object):
(
to_device_edus,
device_stream_id,
) = yield self._get_to_device_message_edus(limit)
) = await self._get_to_device_message_edus(limit)
pending_edus = device_update_edus + to_device_edus
@ -269,7 +273,7 @@ class PerDestinationQueue(object):
# END CRITICAL SECTION
success = yield self._transaction_manager.send_new_transaction(
success = await self._transaction_manager.send_new_transaction(
self._destination, pending_pdus, pending_edus
)
if success:
@ -280,7 +284,7 @@ class PerDestinationQueue(object):
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if to_device_edus:
yield self._store.delete_device_msgs_for_remote(
await self._store.delete_device_msgs_for_remote(
self._destination, device_stream_id
)
@ -289,7 +293,7 @@ class PerDestinationQueue(object):
logger.info(
"Marking as sent %r %r", self._destination, dev_list_id
)
yield self._store.mark_as_sent_devices_by_remote(
await self._store.mark_as_sent_devices_by_remote(
self._destination, dev_list_id
)
@ -334,7 +338,7 @@ class PerDestinationQueue(object):
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
def _get_rr_edus(self, force_flush):
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs:
return
if not force_flush and not self._rrs_pending_flush:
@ -351,17 +355,16 @@ class PerDestinationQueue(object):
self._rrs_pending_flush = False
yield edu
def _pop_pending_edus(self, limit):
def _pop_pending_edus(self, limit: int) -> List[Edu]:
pending_edus = self._pending_edus
pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:]
return pending_edus
@defer.inlineCallbacks
def _get_device_update_edus(self, limit):
async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]:
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_device_updates_by_remote(
now_stream_id, results = await self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit
)
edus = [
@ -378,11 +381,10 @@ class PerDestinationQueue(object):
return (edus, now_stream_id)
@defer.inlineCallbacks
def _get_to_device_message_edus(self, limit):
async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
last_device_stream_id = self._last_device_stream_id
to_device_stream_id = self._store.get_to_device_stream_token()
contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
contents, stream_id = await self._store.get_new_device_msgs_for_remote(
self._destination, last_device_stream_id, to_device_stream_id, limit
)
edus = [

View file

@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from canonicaljson import json
from twisted.internet import defer
import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Transaction
from synapse.federation.units import Edu, Transaction
from synapse.logging.opentracing import (
extract_text_map,
set_tag,
@ -39,7 +40,7 @@ class TransactionManager(object):
shared between PerDestinationQueue objects
"""
def __init__(self, hs):
def __init__(self, hs: "synapse.server.HomeServer"):
self._server_name = hs.hostname
self.clock = hs.get_clock() # nb must be called this for @measure_func
self._store = hs.get_datastore()
@ -50,8 +51,9 @@ class TransactionManager(object):
self._next_txn_id = int(self.clock.time_msec())
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def send_new_transaction(self, destination, pending_pdus, pending_edus):
async def send_new_transaction(
self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu]
):
# Make a transaction-sending opentracing span. This span follows on from
# all the edus in that transaction. This needs to be done since there is
@ -127,7 +129,7 @@ class TransactionManager(object):
return data
try:
response = yield self._transport_layer.send_transaction(
response = await self._transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200

View file

@ -19,11 +19,15 @@ server protocol.
import logging
import attr
from synapse.types import JsonDict
from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__)
@attr.s(slots=True)
class Edu(JsonEncodedObject):
""" An Edu represents a piece of data sent from one homeserver to another.
@ -32,11 +36,24 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph.
"""
valid_keys = ["origin", "destination", "edu_type", "content"]
edu_type = attr.ib(type=str)
content = attr.ib(type=dict)
origin = attr.ib(type=str)
destination = attr.ib(type=str)
required_keys = ["edu_type"]
def get_dict(self) -> JsonDict:
return {
"edu_type": self.edu_type,
"content": self.content,
}
internal_keys = ["origin", "destination"]
def get_internal_dict(self) -> JsonDict:
return {
"edu_type": self.edu_type,
"content": self.content,
"origin": self.origin,
"destination": self.destination,
}
def get_context(self):
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")

View file

@ -107,3 +107,5 @@ class HomeServer(object):
self,
) -> synapse.replication.tcp.client.ReplicationClientHandler:
pass
def is_mine_id(self, domain_id: str) -> bool:
pass

View file

@ -111,7 +111,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
self.datastore.get_device_updates_by_remote.return_value = (0, [])
self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
(0, [])
)
def get_received_txn_response(*args):
return defer.succeed(None)
@ -144,7 +146,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
None

View file

@ -179,6 +179,7 @@ extras = all
commands = mypy \
synapse/api \
synapse/config/ \
synapse/federation/sender \
synapse/federation/transport \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \