0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 21:13:54 +01:00

Type checking for FederationHandler (#7770)

fix a few things to make this pass mypy.
This commit is contained in:
Richard van der Hoff 2020-07-01 16:21:02 +01:00 committed by GitHub
parent 244dbb04f7
commit a6eae69ffe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 17 deletions

1
changelog.d/7770.misc Normal file
View file

@ -0,0 +1 @@
Fix up `synapse.handlers.federation` to pass mypy.

View file

@ -19,8 +19,9 @@
import itertools import itertools
import logging import logging
from collections import Container
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Sequence, Tuple from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -742,6 +743,9 @@ class FederationHandler(BaseHandler):
# device and recognize the algorithm then we can work out the # device and recognize the algorithm then we can work out the
# exact key to expect. Otherwise check it matches any key we # exact key to expect. Otherwise check it matches any key we
# have for that device. # have for that device.
current_keys = [] # type: Container[str]
if device: if device:
keys = device.get("keys", {}).get("keys", {}) keys = device.get("keys", {}).get("keys", {})
@ -758,15 +762,15 @@ class FederationHandler(BaseHandler):
current_keys = keys.values() current_keys = keys.values()
elif device_id: elif device_id:
# We don't have any keys for the device ID. # We don't have any keys for the device ID.
current_keys = [] pass
else: else:
# The event didn't include a device ID, so we just look for # The event didn't include a device ID, so we just look for
# keys across all devices. # keys across all devices.
current_keys = ( current_keys = [
key key
for device in cached_devices for device in cached_devices
for key in device.get("keys", {}).get("keys", {}).values() for key in device.get("keys", {}).get("keys", {}).values()
) ]
# We now check that the sender key matches (one of) the expected # We now check that the sender key matches (one of) the expected
# keys. # keys.
@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN if e_type == EventTypes.Member and event.membership == Membership.JOIN
] ]
joined_domains = {} joined_domains = {} # type: Dict[str, int]
for u, d in joined_users: for u, d in joined_users:
try: try:
dom = get_domain_from_id(u) dom = get_domain_from_id(u)
@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler):
try: try:
# Try the host we successfully got a response to /make_join/ # Try the host we successfully got a response to /make_join/
# request first. # request first.
host_list = list(target_hosts)
try: try:
target_hosts.remove(origin) host_list.remove(origin)
target_hosts.insert(0, origin) host_list.insert(0, origin)
except ValueError: except ValueError:
pass pass
ret = await self.federation_client.send_join( ret = await self.federation_client.send_join(
target_hosts, event, room_version_obj host_list, event, room_version_obj
) )
origin = ret["origin"] origin = ret["origin"]
@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler):
# Try the host that we succesfully called /make_leave/ on first for # Try the host that we succesfully called /make_leave/ on first for
# the /send_leave/ request. # the /send_leave/ request.
host_list = list(target_hosts)
try: try:
target_hosts.remove(origin) host_list.remove(origin)
target_hosts.insert(0, origin) host_list.insert(0, origin)
except ValueError: except ValueError:
pass pass
await self.federation_client.send_leave(target_hosts, event) await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event) context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)]) stream_id = await self.persist_events_and_notify([(event, context)])
@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler):
user_id: str, user_id: str,
membership: str, membership: str,
content: JsonDict = {}, content: JsonDict = {},
params: Optional[Dict[str, str]] = None, params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]: ) -> Tuple[str, EventBase, RoomVersion]:
( (
origin, origin,
@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler):
auth_events_ids = await self.auth.compute_auth_events( auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events = await self.store.get_events(auth_events_ids) auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
@ -2293,10 +2299,10 @@ class FederationHandler(BaseHandler):
remote_auth_chain = await self.federation_client.get_event_auth( remote_auth_chain = await self.federation_client.get_event_auth(
origin, event.room_id, event.event_id origin, event.room_id, event.event_id
) )
except RequestSendFailed as e: except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the # The other side isn't around or doesn't implement the
# endpoint, so lets just bail out. # endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e) logger.info("Failed to get event auth from remote: %s", e1)
return context return context
seen_remotes = await self.store.have_seen_events( seen_remotes = await self.store.have_seen_events(
@ -2774,7 +2780,8 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content) logger.debug("Checking auth on event %r", event.content)
last_exception = None last_exception = None # type: Optional[Exception]
# for each public key in the 3pid invite event # for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event): for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try: try:
@ -2828,6 +2835,12 @@ class FederationHandler(BaseHandler):
return return
except Exception as e: except Exception as e:
last_exception = e last_exception = e
if last_exception is None:
# we can only get here if get_public_keys() returned an empty list
# TODO: make this better
raise RuntimeError("no public key in invite event")
raise last_exception raise last_exception
async def _check_key_revocation(self, public_key, url): async def _check_key_revocation(self, public_key, url):

View file

@ -184,6 +184,7 @@ commands = mypy \
synapse/handlers/auth.py \ synapse/handlers/auth.py \
synapse/handlers/cas_handler.py \ synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \ synapse/handlers/directory.py \
synapse/handlers/federation.py \
synapse/handlers/oidc_handler.py \ synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \ synapse/handlers/presence.py \
synapse/handlers/room_member.py \ synapse/handlers/room_member.py \