Various improvements to the federation client. (#9129)

* Type hints for `FederationClient`.
* Using `async` functions instead of returning `Awaitable` instances.
This commit is contained in:
Patrick Cloke 2021-01-20 07:59:18 -05:00 committed by GitHub
parent a5b9c87ac6
commit 620ecf13b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 58 deletions

1
changelog.d/9129.misc Normal file
View file

@ -0,0 +1 @@
Various improvements to the federation client.

View file

@ -18,6 +18,7 @@ import copy
import itertools import itertools
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@ -26,7 +27,6 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.pdu_destination_tried = {} self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
@ -116,33 +119,32 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict self.pdu_destination_tried[event_id] = destination_dict
@log_function @log_function
def make_query( async def make_query(
self, self,
destination, destination: str,
query_type, query_type: str,
args, args: dict,
retry_on_dns_fail=False, retry_on_dns_fail: bool = False,
ignore_backoff=False, ignore_backoff: bool = False,
): ) -> JsonDict:
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.
Args: Args:
destination (str): Domain name of the remote homeserver destination: Domain name of the remote homeserver
query_type (str): Category of the query type; should match the query_type: Category of the query type; should match the
handler name used in register_query_handler(). handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details args: Mapping of strings to strings containing the details
of the query request. of the query request.
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff: true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
Returns: Returns:
a Awaitable which will eventually yield a JSON object from the The JSON object from the response
response
""" """
sent_queries_counter.labels(query_type).inc() sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query( return await self.transport_layer.make_query(
destination, destination,
query_type, query_type,
args, args,
@ -151,42 +153,52 @@ class FederationClient(FederationBase):
) )
@log_function @log_function
def query_client_keys(self, destination, content, timeout): async def query_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Query device keys for a device hosted on a remote server. """Query device keys for a device hosted on a remote server.
Args: Args:
destination (str): Domain name of the remote homeserver destination: Domain name of the remote homeserver
content (dict): The query content. content: The query content.
Returns: Returns:
an Awaitable which will eventually yield a JSON object from the The JSON object from the response
response
""" """
sent_queries_counter.labels("client_device_keys").inc() sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(destination, content, timeout) return await self.transport_layer.query_client_keys(
destination, content, timeout
)
@log_function @log_function
def query_user_devices(self, destination, user_id, timeout=30000): async def query_user_devices(
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote """Query the device keys for a list of user ids hosted on a remote
server. server.
""" """
sent_queries_counter.labels("user_devices").inc() sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(destination, user_id, timeout) return await self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function @log_function
def claim_client_keys(self, destination, content, timeout): async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
Args: Args:
destination (str): Domain name of the remote homeserver destination: Domain name of the remote homeserver
content (dict): The query content. content: The query content.
Returns: Returns:
an Awaitable which will eventually yield a JSON object from the The JSON object from the response
response
""" """
sent_queries_counter.labels("client_one_time_keys").inc() sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(destination, content, timeout) return await self.transport_layer.claim_client_keys(
destination, content, timeout
)
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str] self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@ -195,10 +207,10 @@ class FederationClient(FederationBase):
given destination server. given destination server.
Args: Args:
dest (str): The remote homeserver to ask. dest: The remote homeserver to ask.
room_id (str): The room_id to backfill. room_id: The room_id to backfill.
limit (int): The maximum number of events to return. limit: The maximum number of events to return.
extremities (list): our current backwards extremities, to backfill from extremities: our current backwards extremities, to backfill from
""" """
logger.debug("backfill extrem=%s", extremities) logger.debug("backfill extrem=%s", extremities)
@ -370,7 +382,7 @@ class FederationClient(FederationBase):
for events that have failed their checks for events that have failed their checks
Returns: Returns:
Deferred : A list of PDUs that have valid signatures and hashes. A list of PDUs that have valid signatures and hashes.
""" """
deferreds = self._check_sigs_and_hashes(room_version, pdus) deferreds = self._check_sigs_and_hashes(room_version, pdus)
@ -418,7 +430,9 @@ class FederationClient(FederationBase):
else: else:
return [p for p in valid_pdus if p] return [p for p in valid_pdus if p]
async def get_event_auth(self, destination, room_id, event_id): async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id) res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
@ -700,18 +714,16 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request) return await self._try_destination_list("send_join", destinations, send_request)
async def _do_send_join(self, destination: str, pdu: EventBase): async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = await self.transport_layer.send_join_v2( return await self.transport_layer.send_join_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@ -769,7 +781,7 @@ class FederationClient(FederationBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = await self.transport_layer.send_invite_v2( return await self.transport_layer.send_invite_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -779,7 +791,6 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []), "invite_room_state": pdu.unsigned.get("invite_room_state", []),
}, },
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@ -842,18 +853,16 @@ class FederationClient(FederationBase):
"send_leave", destinations, send_request "send_leave", destinations, send_request
) )
async def _do_send_leave(self, destination, pdu): async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = await self.transport_layer.send_leave_v2( return await self.transport_layer.send_leave_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@ -879,7 +888,7 @@ class FederationClient(FederationBase):
# content. # content.
return resp[1] return resp[1]
def get_public_rooms( async def get_public_rooms(
self, self,
remote_server: str, remote_server: str,
limit: Optional[int] = None, limit: Optional[int] = None,
@ -887,7 +896,7 @@ class FederationClient(FederationBase):
search_filter: Optional[Dict] = None, search_filter: Optional[Dict] = None,
include_all_networks: bool = False, include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None, third_party_instance_id: Optional[str] = None,
): ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver """Get the list of public rooms from a remote homeserver
Args: Args:
@ -901,8 +910,7 @@ class FederationClient(FederationBase):
party instance party instance
Returns: Returns:
Awaitable[Dict[str, Any]]: The response from the remote server, or None if The response from the remote server.
`remote_server` is the same as the local server_name
Raises: Raises:
HttpResponseException: There was an exception returned from the remote server HttpResponseException: There was an exception returned from the remote server
@ -910,7 +918,7 @@ class FederationClient(FederationBase):
requests over federation requests over federation
""" """
return self.transport_layer.get_public_rooms( return await self.transport_layer.get_public_rooms(
remote_server, remote_server,
limit, limit,
since_token, since_token,
@ -923,7 +931,7 @@ class FederationClient(FederationBase):
self, self,
destination: str, destination: str,
room_id: str, room_id: str,
earliest_events_ids: Sequence[str], earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase], latest_events: Iterable[EventBase],
limit: int, limit: int,
min_depth: int, min_depth: int,
@ -974,7 +982,9 @@ class FederationClient(FederationBase):
return signed_events return signed_events
async def forward_third_party_invite(self, destinations, room_id, event_dict): async def forward_third_party_invite(
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
@ -983,7 +993,7 @@ class FederationClient(FederationBase):
await self.transport_layer.exchange_third_party_invite( await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict destination=destination, room_id=room_id, event_dict=event_dict
) )
return None return
except CodeMessageException: except CodeMessageException:
raise raise
except Exception as e: except Exception as e:
@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
async def get_room_complexity( async def get_room_complexity(
self, destination: str, room_id: str self, destination: str, room_id: str
) -> Optional[dict]: ) -> Optional[JsonDict]:
""" """
Fetch the complexity of a remote room from another server. Fetch the complexity of a remote room from another server.
@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
could not fetch the complexity. could not fetch the complexity.
""" """
try: try:
complexity = await self.transport_layer.get_room_complexity( return await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id destination=destination, room_id=room_id
) )
return complexity
except CodeMessageException as e: except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other # We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us. # servers don't give it to us.