forked from MirrorHub/synapse
Various improvements to the federation client. (#9129)
* Type hints for `FederationClient`. * Using `async` functions instead of returning `Awaitable` instances.
This commit is contained in:
parent
a5b9c87ac6
commit
620ecf13b0
2 changed files with 68 additions and 58 deletions
1
changelog.d/9129.misc
Normal file
1
changelog.d/9129.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Various improvements to the federation client.
|
|
@ -18,6 +18,7 @@ import copy
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
@ -26,7 +27,6 @@ from typing import (
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
Tuple,
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
|
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
|
||||||
|
@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
|
||||||
|
|
||||||
|
|
||||||
class FederationClient(FederationBase):
|
class FederationClient(FederationBase):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.pdu_destination_tried = {}
|
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
|
||||||
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
|
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.transport_layer = hs.get_federation_transport_client()
|
self.transport_layer = hs.get_federation_transport_client()
|
||||||
|
@ -116,33 +119,32 @@ class FederationClient(FederationBase):
|
||||||
self.pdu_destination_tried[event_id] = destination_dict
|
self.pdu_destination_tried[event_id] = destination_dict
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def make_query(
|
async def make_query(
|
||||||
self,
|
self,
|
||||||
destination,
|
destination: str,
|
||||||
query_type,
|
query_type: str,
|
||||||
args,
|
args: dict,
|
||||||
retry_on_dns_fail=False,
|
retry_on_dns_fail: bool = False,
|
||||||
ignore_backoff=False,
|
ignore_backoff: bool = False,
|
||||||
):
|
) -> JsonDict:
|
||||||
"""Sends a federation Query to a remote homeserver of the given type
|
"""Sends a federation Query to a remote homeserver of the given type
|
||||||
and arguments.
|
and arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): Domain name of the remote homeserver
|
destination: Domain name of the remote homeserver
|
||||||
query_type (str): Category of the query type; should match the
|
query_type: Category of the query type; should match the
|
||||||
handler name used in register_query_handler().
|
handler name used in register_query_handler().
|
||||||
args (dict): Mapping of strings to strings containing the details
|
args: Mapping of strings to strings containing the details
|
||||||
of the query request.
|
of the query request.
|
||||||
ignore_backoff (bool): true to ignore the historical backoff data
|
ignore_backoff: true to ignore the historical backoff data
|
||||||
and try the request anyway.
|
and try the request anyway.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a Awaitable which will eventually yield a JSON object from the
|
The JSON object from the response
|
||||||
response
|
|
||||||
"""
|
"""
|
||||||
sent_queries_counter.labels(query_type).inc()
|
sent_queries_counter.labels(query_type).inc()
|
||||||
|
|
||||||
return self.transport_layer.make_query(
|
return await self.transport_layer.make_query(
|
||||||
destination,
|
destination,
|
||||||
query_type,
|
query_type,
|
||||||
args,
|
args,
|
||||||
|
@ -151,42 +153,52 @@ class FederationClient(FederationBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def query_client_keys(self, destination, content, timeout):
|
async def query_client_keys(
|
||||||
|
self, destination: str, content: JsonDict, timeout: int
|
||||||
|
) -> JsonDict:
|
||||||
"""Query device keys for a device hosted on a remote server.
|
"""Query device keys for a device hosted on a remote server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): Domain name of the remote homeserver
|
destination: Domain name of the remote homeserver
|
||||||
content (dict): The query content.
|
content: The query content.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
an Awaitable which will eventually yield a JSON object from the
|
The JSON object from the response
|
||||||
response
|
|
||||||
"""
|
"""
|
||||||
sent_queries_counter.labels("client_device_keys").inc()
|
sent_queries_counter.labels("client_device_keys").inc()
|
||||||
return self.transport_layer.query_client_keys(destination, content, timeout)
|
return await self.transport_layer.query_client_keys(
|
||||||
|
destination, content, timeout
|
||||||
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def query_user_devices(self, destination, user_id, timeout=30000):
|
async def query_user_devices(
|
||||||
|
self, destination: str, user_id: str, timeout: int = 30000
|
||||||
|
) -> JsonDict:
|
||||||
"""Query the device keys for a list of user ids hosted on a remote
|
"""Query the device keys for a list of user ids hosted on a remote
|
||||||
server.
|
server.
|
||||||
"""
|
"""
|
||||||
sent_queries_counter.labels("user_devices").inc()
|
sent_queries_counter.labels("user_devices").inc()
|
||||||
return self.transport_layer.query_user_devices(destination, user_id, timeout)
|
return await self.transport_layer.query_user_devices(
|
||||||
|
destination, user_id, timeout
|
||||||
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def claim_client_keys(self, destination, content, timeout):
|
async def claim_client_keys(
|
||||||
|
self, destination: str, content: JsonDict, timeout: int
|
||||||
|
) -> JsonDict:
|
||||||
"""Claims one-time keys for a device hosted on a remote server.
|
"""Claims one-time keys for a device hosted on a remote server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): Domain name of the remote homeserver
|
destination: Domain name of the remote homeserver
|
||||||
content (dict): The query content.
|
content: The query content.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
an Awaitable which will eventually yield a JSON object from the
|
The JSON object from the response
|
||||||
response
|
|
||||||
"""
|
"""
|
||||||
sent_queries_counter.labels("client_one_time_keys").inc()
|
sent_queries_counter.labels("client_one_time_keys").inc()
|
||||||
return self.transport_layer.claim_client_keys(destination, content, timeout)
|
return await self.transport_layer.claim_client_keys(
|
||||||
|
destination, content, timeout
|
||||||
|
)
|
||||||
|
|
||||||
async def backfill(
|
async def backfill(
|
||||||
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
|
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
|
||||||
|
@ -195,10 +207,10 @@ class FederationClient(FederationBase):
|
||||||
given destination server.
|
given destination server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dest (str): The remote homeserver to ask.
|
dest: The remote homeserver to ask.
|
||||||
room_id (str): The room_id to backfill.
|
room_id: The room_id to backfill.
|
||||||
limit (int): The maximum number of events to return.
|
limit: The maximum number of events to return.
|
||||||
extremities (list): our current backwards extremities, to backfill from
|
extremities: our current backwards extremities, to backfill from
|
||||||
"""
|
"""
|
||||||
logger.debug("backfill extrem=%s", extremities)
|
logger.debug("backfill extrem=%s", extremities)
|
||||||
|
|
||||||
|
@ -370,7 +382,7 @@ class FederationClient(FederationBase):
|
||||||
for events that have failed their checks
|
for events that have failed their checks
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
A list of PDUs that have valid signatures and hashes.
|
||||||
"""
|
"""
|
||||||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
||||||
|
|
||||||
|
@ -418,7 +430,9 @@ class FederationClient(FederationBase):
|
||||||
else:
|
else:
|
||||||
return [p for p in valid_pdus if p]
|
return [p for p in valid_pdus if p]
|
||||||
|
|
||||||
async def get_event_auth(self, destination, room_id, event_id):
|
async def get_event_auth(
|
||||||
|
self, destination: str, room_id: str, event_id: str
|
||||||
|
) -> List[EventBase]:
|
||||||
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
|
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
|
||||||
|
|
||||||
room_version = await self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
@ -700,18 +714,16 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
return await self._try_destination_list("send_join", destinations, send_request)
|
return await self._try_destination_list("send_join", destinations, send_request)
|
||||||
|
|
||||||
async def _do_send_join(self, destination: str, pdu: EventBase):
|
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = await self.transport_layer.send_join_v2(
|
return await self.transport_layer.send_join_v2(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
room_id=pdu.room_id,
|
room_id=pdu.room_id,
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
content=pdu.get_pdu_json(time_now),
|
content=pdu.get_pdu_json(time_now),
|
||||||
)
|
)
|
||||||
|
|
||||||
return content
|
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
if e.code in [400, 404]:
|
if e.code in [400, 404]:
|
||||||
err = e.to_synapse_error()
|
err = e.to_synapse_error()
|
||||||
|
@ -769,7 +781,7 @@ class FederationClient(FederationBase):
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = await self.transport_layer.send_invite_v2(
|
return await self.transport_layer.send_invite_v2(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
room_id=pdu.room_id,
|
room_id=pdu.room_id,
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
|
@ -779,7 +791,6 @@ class FederationClient(FederationBase):
|
||||||
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
|
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return content
|
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
if e.code in [400, 404]:
|
if e.code in [400, 404]:
|
||||||
err = e.to_synapse_error()
|
err = e.to_synapse_error()
|
||||||
|
@ -842,18 +853,16 @@ class FederationClient(FederationBase):
|
||||||
"send_leave", destinations, send_request
|
"send_leave", destinations, send_request
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _do_send_leave(self, destination, pdu):
|
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = await self.transport_layer.send_leave_v2(
|
return await self.transport_layer.send_leave_v2(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
room_id=pdu.room_id,
|
room_id=pdu.room_id,
|
||||||
event_id=pdu.event_id,
|
event_id=pdu.event_id,
|
||||||
content=pdu.get_pdu_json(time_now),
|
content=pdu.get_pdu_json(time_now),
|
||||||
)
|
)
|
||||||
|
|
||||||
return content
|
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
if e.code in [400, 404]:
|
if e.code in [400, 404]:
|
||||||
err = e.to_synapse_error()
|
err = e.to_synapse_error()
|
||||||
|
@ -879,7 +888,7 @@ class FederationClient(FederationBase):
|
||||||
# content.
|
# content.
|
||||||
return resp[1]
|
return resp[1]
|
||||||
|
|
||||||
def get_public_rooms(
|
async def get_public_rooms(
|
||||||
self,
|
self,
|
||||||
remote_server: str,
|
remote_server: str,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
|
@ -887,7 +896,7 @@ class FederationClient(FederationBase):
|
||||||
search_filter: Optional[Dict] = None,
|
search_filter: Optional[Dict] = None,
|
||||||
include_all_networks: bool = False,
|
include_all_networks: bool = False,
|
||||||
third_party_instance_id: Optional[str] = None,
|
third_party_instance_id: Optional[str] = None,
|
||||||
):
|
) -> JsonDict:
|
||||||
"""Get the list of public rooms from a remote homeserver
|
"""Get the list of public rooms from a remote homeserver
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -901,8 +910,7 @@ class FederationClient(FederationBase):
|
||||||
party instance
|
party instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
|
The response from the remote server.
|
||||||
`remote_server` is the same as the local server_name
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HttpResponseException: There was an exception returned from the remote server
|
HttpResponseException: There was an exception returned from the remote server
|
||||||
|
@ -910,7 +918,7 @@ class FederationClient(FederationBase):
|
||||||
requests over federation
|
requests over federation
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self.transport_layer.get_public_rooms(
|
return await self.transport_layer.get_public_rooms(
|
||||||
remote_server,
|
remote_server,
|
||||||
limit,
|
limit,
|
||||||
since_token,
|
since_token,
|
||||||
|
@ -923,7 +931,7 @@ class FederationClient(FederationBase):
|
||||||
self,
|
self,
|
||||||
destination: str,
|
destination: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
earliest_events_ids: Sequence[str],
|
earliest_events_ids: Iterable[str],
|
||||||
latest_events: Iterable[EventBase],
|
latest_events: Iterable[EventBase],
|
||||||
limit: int,
|
limit: int,
|
||||||
min_depth: int,
|
min_depth: int,
|
||||||
|
@ -974,7 +982,9 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
return signed_events
|
return signed_events
|
||||||
|
|
||||||
async def forward_third_party_invite(self, destinations, room_id, event_dict):
|
async def forward_third_party_invite(
|
||||||
|
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
|
||||||
|
) -> None:
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
if destination == self.server_name:
|
if destination == self.server_name:
|
||||||
continue
|
continue
|
||||||
|
@ -983,7 +993,7 @@ class FederationClient(FederationBase):
|
||||||
await self.transport_layer.exchange_third_party_invite(
|
await self.transport_layer.exchange_third_party_invite(
|
||||||
destination=destination, room_id=room_id, event_dict=event_dict
|
destination=destination, room_id=room_id, event_dict=event_dict
|
||||||
)
|
)
|
||||||
return None
|
return
|
||||||
except CodeMessageException:
|
except CodeMessageException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
async def get_room_complexity(
|
async def get_room_complexity(
|
||||||
self, destination: str, room_id: str
|
self, destination: str, room_id: str
|
||||||
) -> Optional[dict]:
|
) -> Optional[JsonDict]:
|
||||||
"""
|
"""
|
||||||
Fetch the complexity of a remote room from another server.
|
Fetch the complexity of a remote room from another server.
|
||||||
|
|
||||||
|
@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
|
||||||
could not fetch the complexity.
|
could not fetch the complexity.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
complexity = await self.transport_layer.get_room_complexity(
|
return await self.transport_layer.get_room_complexity(
|
||||||
destination=destination, room_id=room_id
|
destination=destination, room_id=room_id
|
||||||
)
|
)
|
||||||
return complexity
|
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
# We didn't manage to get it -- probably a 404. We are okay if other
|
# We didn't manage to get it -- probably a 404. We are okay if other
|
||||||
# servers don't give it to us.
|
# servers don't give it to us.
|
||||||
|
|
Loading…
Reference in a new issue