Add type hints to the federation handler and server. (#9743)

This commit is contained in:
Patrick Cloke 2021-04-06 07:21:57 -04:00 committed by GitHub
parent e7b769aea1
commit d959d28730
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 95 deletions

1
changelog.d/9743.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to federation handler and server.

View file

@ -739,22 +739,20 @@ class FederationServer(FederationBase):
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self):
def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
):
ret = await self.handler.exchange_third_party_invite(
) -> None:
await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
return ret
async def on_exchange_third_party_invite_request(self, event_dict: Dict):
ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
return ret
async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
await self.handler.on_exchange_third_party_invite_request(event_dict)
async def check_server_matches_acl(self, server_name: str, room_id: str):
async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
"""Check if the given server is allowed by the server ACLs in the room
Args:
@ -878,7 +876,7 @@ class FederationHandlerRegistry:
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
):
) -> None:
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
@ -897,7 +895,7 @@ class FederationHandlerRegistry:
def register_query_handler(
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
) -> None:
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@ -915,15 +913,17 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str):
def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
"""Register that the EDU handler is on a different instance than master."""
self._edu_type_to_instance[edu_type] = [instance_name]
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
def register_instances_for_edu(
self, edu_type: str, instance_names: List[str]
) -> None:
"""Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict):
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.use_presence and edu_type == EduTypes.Presence:
return

View file

@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request(content)
return 200, content
await self.handler.on_exchange_third_party_invite_request(content)
return 200, {}
class FederationClientKeysQueryServlet(BaseFederationServlet):

View file

@ -21,7 +21,17 @@ import itertools
import logging
from collections.abc import Container
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import attr
from signedjson.key import decode_verify_key_bytes
@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
async def on_receive_pdu(
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
) -> None:
"""Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args:
origin (str): server which initiated the /send/ transaction. Will
origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
pdu: received PDU
sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
"""
@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
await self._process_received_pdu(origin, pdu, state=state)
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
) -> None:
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
origin: Origin of the pdu. Will be called to get the missing events
pdu: received pdu
prevs (set(str)): List of event ids which we are missing
min_depth (int): Minimum depth of events to return.
prevs: List of event ids which we are missing
min_depth: Minimum depth of events to return.
"""
room_id = pdu.room_id
@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
):
) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to resync device for %s", sender)
@log_function
async def backfill(self, dest, room_id, limit, extremities):
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: List[str]
) -> List[EventBase]:
"""Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
Args:
state (dict[tuple, FrozenEvent]): State map from type/state
key to event.
state: State map from type/state key to event.
Returns:
list[tuple[str, int]]: Returns a list of servers with the
lowest depth of their joins. Sorted by lowest depth first.
Returns a list of servers with the lowest depth of their joins.
Sorted by lowest depth first.
"""
joined_users = [
(state_key, int(event.depth))
@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
async def try_backfill(domains):
async def try_backfill(domains: List[str]) -> bool:
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
[
dom
for dom, _ in likely_extremeties_domains
if dom not in tried_domains
]
)
if success:
return True
tried_domains.update(dom for dom, _ in likely_domains)
tried_domains.update(dom for dom, _ in likely_extremeties_domains)
return False
async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
):
) -> None:
"""Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the
@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos,
)
def _sanity_check_event(self, ev):
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
or cascade of event fetches.
Args:
ev (synapse.events.EventBase): event to be checked
Returns: None
ev: event to be checked
Raises:
SynapseError if the event does not pass muster
@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event):
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
async def _handle_queued_pdus(self, room_queue):
async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]]
) -> None:
"""Process PDUs which got queued up while we were busy send_joining.
Args:
room_queue (list[FrozenEvent, str]): list of PDUs to be processed
and the servers that sent them
room_queue: list of PDUs to be processed and the servers that sent them
"""
for p, origin in room_queue:
try:
@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
return event
async def on_send_join_request(self, origin, pdu):
async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
"""We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
):
) -> EventBase:
"""We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
return event
async def on_send_leave_request(self, origin, pdu):
async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
""" We have received a leave event for a room. Fully process it."""
event = pdu
@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
else:
return None
async def get_min_depth_for_context(self, context):
async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)
async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False
):
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
):
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
await self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
logger.debug("on_query_auth returning: %s", ret)
return ret
async def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit
):
self,
origin: str,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
assumes that we have already processed all events in remote_auth
Params:
local_auth (list)
remote_auth (list)
local_auth
remote_auth
Returns:
dict
@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
@log_function
async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
) -> None:
third_party_invite = {"signed": signed}
event_dict = {
@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context)
async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context
):
self,
room_version: str,
event_dict: JsonDict,
event: EventBase,
context: EventContext,
) -> Tuple[EventBase, EventContext]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
EventValidator().validate_new(event, self.config)
return (event, context)
async def _check_signature(self, event, context):
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
Checks that the signature in the event is consistent with its invite.
Args:
event (Event): The m.room.member event to check
context (EventContext):
event: The m.room.member event to check
context:
Raises:
AuthError: if signature didn't match any keys, or key has been
@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
raise last_exception
async def _check_key_revocation(self, public_key, url):
async def _check_key_revocation(self, public_key: str, url: str) -> None:
"""
Checks whether public_key has been revoked.
Args:
public_key (str): base-64 encoded public key.
url (str): Key revocation URL.
public_key: base-64 encoded public key.
url: Key revocation URL.
Raises:
AuthError: if they key has been revoked.