forked from MirrorHub/synapse
Add type hints to the federation handler and server. (#9743)
This commit is contained in:
parent
e7b769aea1
commit
d959d28730
4 changed files with 97 additions and 95 deletions
1
changelog.d/9743.misc
Normal file
1
changelog.d/9743.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add missing type hints to federation handler and server.
|
|
@ -739,22 +739,20 @@ class FederationServer(FederationBase):
|
|||
|
||||
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
async def exchange_third_party_invite(
|
||||
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
|
||||
):
|
||||
ret = await self.handler.exchange_third_party_invite(
|
||||
) -> None:
|
||||
await self.handler.exchange_third_party_invite(
|
||||
sender_user_id, target_user_id, room_id, signed
|
||||
)
|
||||
return ret
|
||||
|
||||
async def on_exchange_third_party_invite_request(self, event_dict: Dict):
|
||||
ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
|
||||
return ret
|
||||
async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
|
||||
await self.handler.on_exchange_third_party_invite_request(event_dict)
|
||||
|
||||
async def check_server_matches_acl(self, server_name: str, room_id: str):
|
||||
async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
|
||||
"""Check if the given server is allowed by the server ACLs in the room
|
||||
|
||||
Args:
|
||||
|
@ -878,7 +876,7 @@ class FederationHandlerRegistry:
|
|||
|
||||
def register_edu_handler(
|
||||
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
|
||||
):
|
||||
) -> None:
|
||||
"""Sets the handler callable that will be used to handle an incoming
|
||||
federation EDU of the given type.
|
||||
|
||||
|
@ -897,7 +895,7 @@ class FederationHandlerRegistry:
|
|||
|
||||
def register_query_handler(
|
||||
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||
):
|
||||
) -> None:
|
||||
"""Sets the handler callable that will be used to handle an incoming
|
||||
federation query of the given type.
|
||||
|
||||
|
@ -915,15 +913,17 @@ class FederationHandlerRegistry:
|
|||
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
def register_instance_for_edu(self, edu_type: str, instance_name: str):
|
||||
def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
|
||||
"""Register that the EDU handler is on a different instance than master."""
|
||||
self._edu_type_to_instance[edu_type] = [instance_name]
|
||||
|
||||
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
|
||||
def register_instances_for_edu(
|
||||
self, edu_type: str, instance_names: List[str]
|
||||
) -> None:
|
||||
"""Register that the EDU handler is on multiple instances."""
|
||||
self._edu_type_to_instance[edu_type] = instance_names
|
||||
|
||||
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
||||
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
|
||||
if not self.config.use_presence and edu_type == EduTypes.Presence:
|
||||
return
|
||||
|
||||
|
|
|
@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
|
|||
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
|
||||
|
||||
async def on_PUT(self, origin, content, query, room_id):
|
||||
content = await self.handler.on_exchange_third_party_invite_request(content)
|
||||
return 200, content
|
||||
await self.handler.on_exchange_third_party_invite_request(content)
|
||||
return 200, {}
|
||||
|
||||
|
||||
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||
|
|
|
@ -21,7 +21,17 @@ import itertools
|
|||
import logging
|
||||
from collections.abc import Container
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
|
||||
async def on_receive_pdu(
|
||||
self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
|
||||
) -> None:
|
||||
"""Process a PDU received via a federation /send/ transaction, or
|
||||
via backfill of missing prev_events
|
||||
|
||||
Args:
|
||||
origin (str): server which initiated the /send/ transaction. Will
|
||||
origin: server which initiated the /send/ transaction. Will
|
||||
be used to fetch missing events or state.
|
||||
pdu (FrozenEvent): received PDU
|
||||
sent_to_us_directly (bool): True if this event was pushed to us; False if
|
||||
pdu: received PDU
|
||||
sent_to_us_directly: True if this event was pushed to us; False if
|
||||
we pulled it as the result of a missing prev_event.
|
||||
"""
|
||||
|
||||
|
@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
await self._process_received_pdu(origin, pdu, state=state)
|
||||
|
||||
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
async def _get_missing_events_for_pdu(
|
||||
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
origin (str): Origin of the pdu. Will be called to get the missing events
|
||||
origin: Origin of the pdu. Will be called to get the missing events
|
||||
pdu: received pdu
|
||||
prevs (set(str)): List of event ids which we are missing
|
||||
min_depth (int): Minimum depth of events to return.
|
||||
prevs: List of event ids which we are missing
|
||||
min_depth: Minimum depth of events to return.
|
||||
"""
|
||||
|
||||
room_id = pdu.room_id
|
||||
|
@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
|
|||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
):
|
||||
) -> None:
|
||||
"""Called when we have a new pdu. We need to do auth checks and put it
|
||||
through the StateHandler.
|
||||
|
||||
|
@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
|
|||
logger.exception("Failed to resync device for %s", sender)
|
||||
|
||||
@log_function
|
||||
async def backfill(self, dest, room_id, limit, extremities):
|
||||
async def backfill(
|
||||
self, dest: str, room_id: str, limit: int, extremities: List[str]
|
||||
) -> List[EventBase]:
|
||||
"""Trigger a backfill request to `dest` for the given `room_id`
|
||||
|
||||
This will attempt to get more events from the remote. If the other side
|
||||
|
@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
|
||||
def get_domains_from_state(state):
|
||||
def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
|
||||
"""Get joined domains from state
|
||||
|
||||
Args:
|
||||
state (dict[tuple, FrozenEvent]): State map from type/state
|
||||
key to event.
|
||||
state: State map from type/state key to event.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: Returns a list of servers with the
|
||||
lowest depth of their joins. Sorted by lowest depth first.
|
||||
Returns a list of servers with the lowest depth of their joins.
|
||||
Sorted by lowest depth first.
|
||||
"""
|
||||
joined_users = [
|
||||
(state_key, int(event.depth))
|
||||
|
@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
|
|||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
]
|
||||
|
||||
async def try_backfill(domains):
|
||||
async def try_backfill(domains: List[str]) -> bool:
|
||||
# TODO: Should we try multiple of these at a time?
|
||||
for dom in domains:
|
||||
try:
|
||||
|
@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
|
|||
}
|
||||
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
likely_domains = get_domains_from_state(states[e_id])
|
||||
likely_extremeties_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
success = await try_backfill(
|
||||
[dom for dom, _ in likely_domains if dom not in tried_domains]
|
||||
[
|
||||
dom
|
||||
for dom, _ in likely_extremeties_domains
|
||||
if dom not in tried_domains
|
||||
]
|
||||
)
|
||||
if success:
|
||||
return True
|
||||
|
||||
tried_domains.update(dom for dom, _ in likely_domains)
|
||||
tried_domains.update(dom for dom, _ in likely_extremeties_domains)
|
||||
|
||||
return False
|
||||
|
||||
async def _get_events_and_persist(
|
||||
self, destination: str, room_id: str, events: Iterable[str]
|
||||
):
|
||||
) -> None:
|
||||
"""Fetch the given events from a server, and persist them as outliers.
|
||||
|
||||
This function *does not* recursively get missing auth events of the
|
||||
|
@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
|
|||
event_infos,
|
||||
)
|
||||
|
||||
def _sanity_check_event(self, ev):
|
||||
def _sanity_check_event(self, ev: EventBase) -> None:
|
||||
"""
|
||||
Do some early sanity checks of a received event
|
||||
|
||||
|
@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
|
|||
or cascade of event fetches.
|
||||
|
||||
Args:
|
||||
ev (synapse.events.EventBase): event to be checked
|
||||
|
||||
Returns: None
|
||||
ev: event to be checked
|
||||
|
||||
Raises:
|
||||
SynapseError if the event does not pass muster
|
||||
|
@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
|
||||
|
||||
async def send_invite(self, target_host, event):
|
||||
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
|
||||
"""Sends the invite to the remote server for signing.
|
||||
|
||||
Invites must be signed by the invitee's server before distribution.
|
||||
|
@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
run_in_background(self._handle_queued_pdus, room_queue)
|
||||
|
||||
async def _handle_queued_pdus(self, room_queue):
|
||||
async def _handle_queued_pdus(
|
||||
self, room_queue: List[Tuple[EventBase, str]]
|
||||
) -> None:
|
||||
"""Process PDUs which got queued up while we were busy send_joining.
|
||||
|
||||
Args:
|
||||
room_queue (list[FrozenEvent, str]): list of PDUs to be processed
|
||||
and the servers that sent them
|
||||
room_queue: list of PDUs to be processed and the servers that sent them
|
||||
"""
|
||||
for p, origin in room_queue:
|
||||
try:
|
||||
|
@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return event
|
||||
|
||||
async def on_send_join_request(self, origin, pdu):
|
||||
async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
|
||||
"""We have received a join event for a room. Fully process it and
|
||||
respond with the current state and auth chains.
|
||||
"""
|
||||
|
@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
async def on_invite_request(
|
||||
self, origin: str, event: EventBase, room_version: RoomVersion
|
||||
):
|
||||
) -> EventBase:
|
||||
"""We've got an invite event. Process and persist it. Sign it.
|
||||
|
||||
Respond with the now signed event.
|
||||
|
@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return event
|
||||
|
||||
async def on_send_leave_request(self, origin, pdu):
|
||||
async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
|
||||
""" We have received a leave event for a room. Fully process it."""
|
||||
event = pdu
|
||||
|
||||
|
@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
|
|||
else:
|
||||
return None
|
||||
|
||||
async def get_min_depth_for_context(self, context):
|
||||
async def get_min_depth_for_context(self, context: str) -> int:
|
||||
return await self.store.get_min_depth(context)
|
||||
|
||||
async def _handle_new_event(
|
||||
self, origin, event, state=None, auth_events=None, backfilled=False
|
||||
):
|
||||
self,
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]] = None,
|
||||
auth_events: Optional[MutableStateMap[EventBase]] = None,
|
||||
backfilled: bool = False,
|
||||
) -> EventContext:
|
||||
context = await self._prep_event(
|
||||
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
|
||||
)
|
||||
|
@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
|
|||
logger.warning("Soft-failing %r because %s", event, e)
|
||||
event.internal_metadata.soft_failed = True
|
||||
|
||||
async def on_query_auth(
|
||||
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
|
||||
):
|
||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
# Just go through and process each event in `remote_auth_chain`. We
|
||||
# don't want to fall into the trap of `missing` being wrong.
|
||||
for e in remote_auth_chain:
|
||||
try:
|
||||
await self._handle_new_event(origin, e)
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
# Now get the current auth_chain for the event.
|
||||
local_auth_chain = await self.store.get_auth_chain(
|
||||
room_id, list(event.auth_event_ids()), include_given=True
|
||||
)
|
||||
|
||||
# TODO: Check if we would now reject event_id. If so we need to tell
|
||||
# everyone.
|
||||
|
||||
ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
|
||||
|
||||
logger.debug("on_query_auth returning: %s", ret)
|
||||
|
||||
return ret
|
||||
|
||||
async def on_get_missing_events(
|
||||
self, origin, room_id, earliest_events, latest_events, limit
|
||||
):
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
earliest_events: List[str],
|
||||
latest_events: List[str],
|
||||
limit: int,
|
||||
) -> List[EventBase]:
|
||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
|
|||
assumes that we have already processed all events in remote_auth
|
||||
|
||||
Params:
|
||||
local_auth (list)
|
||||
remote_auth (list)
|
||||
local_auth
|
||||
remote_auth
|
||||
|
||||
Returns:
|
||||
dict
|
||||
|
@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@log_function
|
||||
async def exchange_third_party_invite(
|
||||
self, sender_user_id, target_user_id, room_id, signed
|
||||
):
|
||||
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
|
||||
) -> None:
|
||||
third_party_invite = {"signed": signed}
|
||||
|
||||
event_dict = {
|
||||
|
@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
|
|||
await member_handler.send_membership_event(None, event, context)
|
||||
|
||||
async def add_display_name_to_third_party_invite(
|
||||
self, room_version, event_dict, event, context
|
||||
):
|
||||
self,
|
||||
room_version: str,
|
||||
event_dict: JsonDict,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
key = (
|
||||
EventTypes.ThirdPartyInvite,
|
||||
event.content["third_party_invite"]["signed"]["token"],
|
||||
|
@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
|
|||
EventValidator().validate_new(event, self.config)
|
||||
return (event, context)
|
||||
|
||||
async def _check_signature(self, event, context):
|
||||
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
|
||||
"""
|
||||
Checks that the signature in the event is consistent with its invite.
|
||||
|
||||
Args:
|
||||
event (Event): The m.room.member event to check
|
||||
context (EventContext):
|
||||
event: The m.room.member event to check
|
||||
context:
|
||||
|
||||
Raises:
|
||||
AuthError: if signature didn't match any keys, or key has been
|
||||
|
@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
raise last_exception
|
||||
|
||||
async def _check_key_revocation(self, public_key, url):
|
||||
async def _check_key_revocation(self, public_key: str, url: str) -> None:
|
||||
"""
|
||||
Checks whether public_key has been revoked.
|
||||
|
||||
Args:
|
||||
public_key (str): base-64 encoded public key.
|
||||
url (str): Key revocation URL.
|
||||
public_key: base-64 encoded public key.
|
||||
url: Key revocation URL.
|
||||
|
||||
Raises:
|
||||
AuthError: if they key has been revoked.
|
||||
|
|
Loading…
Reference in a new issue