Add typing information to federation_server. (#7219)

This commit is contained in:
Patrick Cloke 2020-04-07 15:03:23 -04:00 committed by GitHub
parent ec5ac8e2b1
commit d78cb31588
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 65 deletions

1
changelog.d/7219.misc Normal file
View file

@ -0,0 +1 @@
Add typing information to federation server code.

View file

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
import six import six
from six import iteritems from six import iteritems
@ -38,6 +38,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
@ -94,7 +95,9 @@ class FederationServer(FederationBase):
# come in waves. # come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
async def on_backfill_request(self, origin, room_id, versions, limit): async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
with (await self._server_linearizer.queue((origin, room_id))): with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
@ -107,23 +110,25 @@ class FederationServer(FederationBase):
return 200, res return 200, res
async def on_incoming_transaction(self, origin, transaction_data): async def on_incoming_transaction(
self, origin: str, transaction_data: JsonDict
) -> Tuple[int, Dict[str, Any]]:
# keep this as early as possible to make the calculated origin ts as # keep this as early as possible to make the calculated origin ts as
# accurate as possible. # accurate as possible.
request_time = self._clock.time_msec() request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
if not transaction.transaction_id: if not transaction.transaction_id: # type: ignore
raise Exception("Transaction missing transaction_id") raise Exception("Transaction missing transaction_id")
logger.debug("[%s] Got transaction", transaction.transaction_id) logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
# use a linearizer to ensure that we don't process the same transaction # use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel. # multiple times in parallel.
with ( with (
await self._transaction_linearizer.queue( await self._transaction_linearizer.queue(
(origin, transaction.transaction_id) (origin, transaction.transaction_id) # type: ignore
) )
): ):
result = await self._handle_incoming_transaction( result = await self._handle_incoming_transaction(
@ -132,31 +137,33 @@ class FederationServer(FederationBase):
return result return result
async def _handle_incoming_transaction(self, origin, transaction, request_time): async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
""" Process an incoming transaction and return the HTTP response """ Process an incoming transaction and return the HTTP response
Args: Args:
origin (unicode): the server making the request origin: the server making the request
transaction (Transaction): incoming transaction transaction: incoming transaction
request_time (int): timestamp that the HTTP request arrived at request_time: timestamp that the HTTP request arrived at
Returns: Returns:
Deferred[(int, object)]: http response code and body HTTP response code and body
""" """
response = await self.transaction_actions.have_responded(origin, transaction) response = await self.transaction_actions.have_responded(origin, transaction)
if response: if response:
logger.debug( logger.debug(
"[%s] We've already responded to this request", "[%s] We've already responded to this request",
transaction.transaction_id, transaction.transaction_id, # type: ignore
) )
return response return response
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
# Reject if PDU count > 50 or EDU count > 100 # Reject if PDU count > 50 or EDU count > 100
if len(transaction.pdus) > 50 or ( if len(transaction.pdus) > 50 or ( # type: ignore
hasattr(transaction, "edus") and len(transaction.edus) > 100 hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
): ):
logger.info("Transaction PDU or EDU count too large. Returning 400") logger.info("Transaction PDU or EDU count too large. Returning 400")
@ -204,13 +211,13 @@ class FederationServer(FederationBase):
report back to the sending server. report back to the sending server.
""" """
received_pdus_counter.inc(len(transaction.pdus)) received_pdus_counter.inc(len(transaction.pdus)) # type: ignore
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
pdus_by_room = {} pdus_by_room = {} # type: Dict[str, List[EventBase]]
for p in transaction.pdus: for p in transaction.pdus: # type: ignore
if "unsigned" in p: if "unsigned" in p:
unsigned = p["unsigned"] unsigned = p["unsigned"]
if "age" in unsigned: if "age" in unsigned:
@ -254,7 +261,7 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but # require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu. # impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id): async def process_pdus_for_room(room_id: str):
logger.debug("Processing PDUs for %s", room_id) logger.debug("Processing PDUs for %s", room_id)
try: try:
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
@ -310,7 +317,9 @@ class FederationServer(FederationBase):
TRANSACTION_CONCURRENCY_LIMIT, TRANSACTION_CONCURRENCY_LIMIT,
) )
async def on_context_state_request(self, origin, room_id, event_id): async def on_context_state_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
@ -338,7 +347,9 @@ class FederationServer(FederationBase):
return 200, resp return 200, resp
async def on_state_ids_request(self, origin, room_id, event_id): async def on_state_ids_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
if not event_id: if not event_id:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
@ -354,7 +365,9 @@ class FederationServer(FederationBase):
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(self, room_id, event_id): async def _on_context_state_request_compute(
self, room_id: str, event_id: str
) -> Dict[str, list]:
if event_id: if event_id:
pdus = await self.handler.get_state_for_pdu(room_id, event_id) pdus = await self.handler.get_state_for_pdu(room_id, event_id)
else: else:
@ -367,7 +380,9 @@ class FederationServer(FederationBase):
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
} }
async def on_pdu_request(self, origin, event_id): async def on_pdu_request(
self, origin: str, event_id: str
) -> Tuple[int, Union[JsonDict, str]]:
pdu = await self.handler.get_persisted_pdu(origin, event_id) pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu: if pdu:
@ -375,12 +390,16 @@ class FederationServer(FederationBase):
else: else:
return 404, "" return 404, ""
async def on_query_request(self, query_type, args): async def on_query_request(
self, query_type: str, args: Dict[str, str]
) -> Tuple[int, Dict[str, Any]]:
received_queries_counter.labels(query_type).inc() received_queries_counter.labels(query_type).inc()
resp = await self.registry.on_query(query_type, args) resp = await self.registry.on_query(query_type, args)
return 200, resp return 200, resp
async def on_make_join_request(self, origin, room_id, user_id, supported_versions): async def on_make_join_request(
self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
) -> Dict[str, Any]:
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
@ -397,7 +416,7 @@ class FederationServer(FederationBase):
async def on_invite_request( async def on_invite_request(
self, origin: str, content: JsonDict, room_version_id: str self, origin: str, content: JsonDict, room_version_id: str
): ) -> Dict[str, Any]:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version: if not room_version:
raise SynapseError( raise SynapseError(
@ -414,7 +433,9 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)} return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request(self, origin, content, room_id): async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
) -> Dict[str, Any]:
logger.debug("on_send_join_request: content: %s", content) logger.debug("on_send_join_request: content: %s", content)
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
@ -434,7 +455,9 @@ class FederationServer(FederationBase):
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
} }
async def on_make_leave_request(self, origin, room_id, user_id): async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> Dict[str, Any]:
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
@ -444,7 +467,9 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
async def on_send_leave_request(self, origin, content, room_id): async def on_send_leave_request(
self, origin: str, content: JsonDict, room_id: str
) -> dict:
logger.debug("on_send_leave_request: content: %s", content) logger.debug("on_send_leave_request: content: %s", content)
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
@ -460,7 +485,9 @@ class FederationServer(FederationBase):
await self.handler.on_send_leave_request(origin, pdu) await self.handler.on_send_leave_request(origin, pdu)
return {} return {}
async def on_event_auth(self, origin, room_id, event_id): async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
with (await self._server_linearizer.queue((origin, room_id))): with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
@ -471,15 +498,21 @@ class FederationServer(FederationBase):
return 200, res return 200, res
@log_function @log_function
def on_query_client_keys(self, origin, content): async def on_query_client_keys(
return self.on_query_request("client_keys", content) self, origin: str, content: Dict[str, str]
) -> Tuple[int, Dict[str, Any]]:
return await self.on_query_request("client_keys", content)
async def on_query_user_devices(self, origin: str, user_id: str): async def on_query_user_devices(
self, origin: str, user_id: str
) -> Tuple[int, Dict[str, Any]]:
keys = await self.device_handler.on_federation_query_user_devices(user_id) keys = await self.device_handler.on_federation_query_user_devices(user_id)
return 200, keys return 200, keys
@trace @trace
async def on_claim_client_keys(self, origin, content): async def on_claim_client_keys(
self, origin: str, content: JsonDict
) -> Dict[str, Any]:
query = [] query = []
for user_id, device_keys in content.get("one_time_keys", {}).items(): for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items(): for device_id, algorithm in device_keys.items():
@ -488,7 +521,7 @@ class FederationServer(FederationBase):
log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self.store.claim_e2e_one_time_keys(query) results = await self.store.claim_e2e_one_time_keys(query)
json_result = {} json_result = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items(): for key_id, json_bytes in keys.items():
@ -511,8 +544,13 @@ class FederationServer(FederationBase):
return {"one_time_keys": json_result} return {"one_time_keys": json_result}
async def on_get_missing_events( async def on_get_missing_events(
self, origin, room_id, earliest_events, latest_events, limit self,
): origin: str,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> Dict[str, list]:
with (await self._server_linearizer.queue((origin, room_id))): with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id) await self.check_server_matches_acl(origin_host, room_id)
@ -541,11 +579,11 @@ class FederationServer(FederationBase):
return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
@log_function @log_function
def on_openid_userinfo(self, token): async def on_openid_userinfo(self, token: str) -> Optional[str]:
ts_now_ms = self._clock.time_msec() ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms) return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
def _transaction_from_pdus(self, pdu_list): def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction:
"""Returns a new Transaction containing the given PDUs suitable for """Returns a new Transaction containing the given PDUs suitable for
transmission. transmission.
""" """
@ -558,7 +596,7 @@ class FederationServer(FederationBase):
destination=None, destination=None,
) )
async def _handle_received_pdu(self, origin, pdu): async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
""" Process a PDU received in a federation /send/ transaction. """ Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError. If the event is invalid, then this method throws a FederationError.
@ -579,10 +617,8 @@ class FederationServer(FederationBase):
until we try to backfill across the discontinuity. until we try to backfill across the discontinuity.
Args: Args:
origin (str): server which sent the pdu origin: server which sent the pdu
pdu (FrozenEvent): received pdu pdu: received pdu
Returns (Deferred): completes with None
Raises: FederationError if the signatures / hash do not match, or Raises: FederationError if the signatures / hash do not match, or
if the event was unacceptable for any other reason (eg, too large, if the event was unacceptable for any other reason (eg, too large,
@ -625,25 +661,27 @@ class FederationServer(FederationBase):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
): ):
ret = await self.handler.exchange_third_party_invite( ret = await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed sender_user_id, target_user_id, room_id, signed
) )
return ret return ret
async def on_exchange_third_party_invite_request(self, room_id, event_dict): async def on_exchange_third_party_invite_request(
self, room_id: str, event_dict: Dict
):
ret = await self.handler.on_exchange_third_party_invite_request( ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict room_id, event_dict
) )
return ret return ret
async def check_server_matches_acl(self, server_name, room_id): async def check_server_matches_acl(self, server_name: str, room_id: str):
"""Check if the given server is allowed by the server ACLs in the room """Check if the given server is allowed by the server ACLs in the room
Args: Args:
server_name (str): name of server, *without any port part* server_name: name of server, *without any port part*
room_id (str): ID of the room to check room_id: ID of the room to check
Raises: Raises:
AuthError if the server does not match the ACL AuthError if the server does not match the ACL
@ -661,15 +699,15 @@ class FederationServer(FederationBase):
raise AuthError(code=403, msg="Server is banned from room") raise AuthError(code=403, msg="Server is banned from room")
def server_matches_acl_event(server_name, acl_event): def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event """Check if the given server is allowed by the ACL event
Args: Args:
server_name (str): name of server, without any port part server_name: name of server, without any port part
acl_event (EventBase): m.room.server_acl event acl_event: m.room.server_acl event
Returns: Returns:
bool: True if this server is allowed by the ACLs True if this server is allowed by the ACLs
""" """
logger.debug("Checking %s against acl %s", server_name, acl_event.content) logger.debug("Checking %s against acl %s", server_name, acl_event.content)
@ -713,7 +751,7 @@ def server_matches_acl_event(server_name, acl_event):
return False return False
def _acl_entry_matches(server_name, acl_entry): def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
if not isinstance(acl_entry, six.string_types): if not isinstance(acl_entry, six.string_types):
logger.warning( logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
@ -732,13 +770,13 @@ class FederationHandlerRegistry(object):
self.edu_handlers = {} self.edu_handlers = {}
self.query_handlers = {} self.query_handlers = {}
def register_edu_handler(self, edu_type, handler): def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation EDU of the given type. federation EDU of the given type.
Args: Args:
edu_type (str): The type of the incoming EDU to register handler for edu_type: The type of the incoming EDU to register handler for
handler (Callable[[str, dict]]): A callable invoked on incoming EDU handler: A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and of the given type. The arguments are the origin server name and
the EDU contents. the EDU contents.
""" """
@ -749,14 +787,16 @@ class FederationHandlerRegistry(object):
self.edu_handlers[edu_type] = handler self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler): def register_query_handler(
self, query_type: str, handler: Callable[[dict], defer.Deferred]
):
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation query of the given type. federation query of the given type.
Args: Args:
query_type (str): Category name of the query, which should match query_type: Category name of the query, which should match
the string used by make_query. the string used by make_query.
handler (Callable[[dict], Deferred[dict]]): Invoked to handle handler: Invoked to handle
incoming queries of this type. The return will be yielded incoming queries of this type. The return will be yielded
on and the result used as the response to the query request. on and the result used as the response to the query request.
""" """
@ -767,10 +807,11 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
async def on_edu(self, edu_type, origin, content): async def on_edu(self, edu_type: str, origin: str, content: dict):
handler = self.edu_handlers.get(edu_type) handler = self.edu_handlers.get(edu_type)
if not handler: if not handler:
logger.warning("No handler registered for EDU type %s", edu_type) logger.warning("No handler registered for EDU type %s", edu_type)
return
with start_active_span_from_edu(content, "handle_edu"): with start_active_span_from_edu(content, "handle_edu"):
try: try:
@ -780,7 +821,7 @@ class FederationHandlerRegistry(object):
except Exception: except Exception:
logger.exception("Failed to handle edu %r", edu_type) logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args): def on_query(self, query_type: str, args: dict) -> defer.Deferred:
handler = self.query_handlers.get(query_type) handler = self.query_handlers.get(query_type)
if not handler: if not handler:
logger.warning("No handler registered for query type %s", query_type) logger.warning("No handler registered for query type %s", query_type)
@ -807,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
super(ReplicationFederationHandlerRegistry, self).__init__() super(ReplicationFederationHandlerRegistry, self).__init__()
async def on_edu(self, edu_type, origin, content): async def on_edu(self, edu_type: str, origin: str, content: dict):
"""Overrides FederationHandlerRegistry """Overrides FederationHandlerRegistry
""" """
if not self.config.use_presence and edu_type == "m.presence": if not self.config.use_presence and edu_type == "m.presence":
@ -821,7 +862,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
return await self._send_edu(edu_type=edu_type, origin=origin, content=content) return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
async def on_query(self, query_type, args): async def on_query(self, query_type: str, args: dict):
"""Overrides FederationHandlerRegistry """Overrides FederationHandlerRegistry
""" """
handler = self.query_handlers.get(query_type) handler = self.query_handlers.get(query_type)

View file

@ -183,6 +183,7 @@ commands = mypy \
synapse/events/spamcheck.py \ synapse/events/spamcheck.py \
synapse/federation/federation_base.py \ synapse/federation/federation_base.py \
synapse/federation/federation_client.py \ synapse/federation/federation_client.py \
synapse/federation/federation_server.py \
synapse/federation/sender \ synapse/federation/sender \
synapse/federation/transport \ synapse/federation/transport \
synapse/handlers/auth.py \ synapse/handlers/auth.py \