diff --git a/changelog.d/15465.misc b/changelog.d/15465.misc new file mode 100644 index 000000000..93ceaeafc --- /dev/null +++ b/changelog.d/15465.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/mypy.ini b/mypy.ini index 945f7925c..8fb87b9b7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -33,12 +33,6 @@ exclude = (?x) |synapse/storage/schema/ )$ -[mypy-synapse.federation.transport.client] -disallow_untyped_defs = False - -[mypy-synapse.http.matrixfederationclient] -disallow_untyped_defs = False - [mypy-synapse.metrics._reactor_metrics] disallow_untyped_defs = False # This module imports select.epoll. That exists on Linux, but doesn't on macOS. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4cf4957a4..ba34573d4 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -280,15 +280,11 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) if not isinstance(transaction_data, dict): - # TODO we probably want an exception type specific to federation - # client validation. - raise TypeError("Backfill transaction_data is not a dict.") + raise InvalidResponseError("Backfill transaction_data is not a dict.") transaction_data_pdus = transaction_data.get("pdus") if not isinstance(transaction_data_pdus, list): - # TODO we probably want an exception type specific to federation - # client validation. - raise TypeError("transaction_data.pdus is not a list.") + raise InvalidResponseError("transaction_data.pdus is not a list.") room_version = await self.store.get_room_version(room_id) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index c05d598b7..bedbd23de 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -16,6 +16,7 @@ import logging import urllib from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -42,18 +43,21 @@ from synapse.api.urls import ( ) from synapse.events import EventBase, make_event_from_dict from synapse.federation.units import Transaction -from synapse.http.matrixfederationclient import ByteParser +from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser from synapse.http.types import QueryParams from synapse.types import JsonDict from synapse.util import ExceptionBundle +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class TransportLayerClient: """Sends federation HTTP requests to other servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.client = hs.get_federation_http_client() self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled @@ -133,7 +137,7 @@ class TransportLayerClient: async def backfill( self, destination: str, room_id: str, event_tuples: Collection[str], limit: int - ) -> Optional[JsonDict]: + ) -> Optional[Union[JsonDict, list]]: """Requests `limit` previous PDUs in a given context before list of PDUs. @@ -388,6 +392,7 @@ class TransportLayerClient: # server was just having a momentary blip, the room will be out of # sync. ignore_backoff=True, + parser=LegacyJsonSendParser(), ) async def send_leave_v2( @@ -445,7 +450,11 @@ class TransportLayerClient: path = _create_v1_path("/invite/%s/%s", room_id, event_id) return await self.client.put_json( - destination=destination, path=path, data=content, ignore_backoff=True + destination=destination, + path=path, + data=content, + ignore_backoff=True, + parser=LegacyJsonSendParser(), ) async def send_invite_v2( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3302d4e48..634882487 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,7 +17,6 @@ import codecs import logging import random import sys -import typing import urllib.parse from http import HTTPStatus from io import BytesIO, StringIO @@ -30,9 +29,11 @@ from typing import ( Generic, List, Optional, + TextIO, Tuple, TypeVar, Union, + cast, overload, ) @@ -183,20 +184,61 @@ class MatrixFederationRequest: return self.json -class JsonParser(ByteParser[Union[JsonDict, list]]): +class _BaseJsonParser(ByteParser[T]): """A parser that buffers the response and tries to parse it as JSON.""" CONTENT_TYPE = "application/json" - def __init__(self) -> None: + def __init__( + self, validator: Optional[Callable[[Optional[object]], bool]] = None + ) -> None: + """ + Args: + validator: A callable which takes the parsed JSON value and returns + true if the value is valid. + """ self._buffer = StringIO() self._binary_wrapper = BinaryIOWrapper(self._buffer) + self._validator = validator def write(self, data: bytes) -> int: return self._binary_wrapper.write(data) - def finish(self) -> Union[JsonDict, list]: - return json_decoder.decode(self._buffer.getvalue()) + def finish(self) -> T: + result = json_decoder.decode(self._buffer.getvalue()) + if self._validator is not None and not self._validator(result): + raise ValueError( + f"Received incorrect JSON value: {result.__class__.__name__}" + ) + return result + + +class JsonParser(_BaseJsonParser[JsonDict]): + """A parser that buffers the response and tries to parse it as a JSON object.""" + + def __init__(self) -> None: + super().__init__(self._validate) + + @staticmethod + def _validate(v: Any) -> bool: + return isinstance(v, dict) + + +class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]): + """Ensure the legacy responses of /send_join & /send_leave are correct.""" + + def __init__(self) -> None: + super().__init__(self._validate) + + @staticmethod + def _validate(v: Any) -> bool: + # Match [integer, JSON dict] + return ( + isinstance(v, list) + and len(v) == 2 + and type(v[0]) == int + and isinstance(v[1], dict) + ) async def _handle_response( @@ -313,9 +355,7 @@ async def _handle_response( class BinaryIOWrapper: """A wrapper for a TextIO which converts from bytes on the fly.""" - def __init__( - self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict" - ): + def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"): self.decoder = codecs.getincrementaldecoder(encoding)(errors) self.file = file @@ -793,7 +833,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: ... @overload @@ -825,8 +865,8 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser] = None, - ): + parser: Optional[ByteParser[T]] = None, + ) -> Union[JsonDict, T]: """Sends the specified json data using PUT Args: @@ -902,7 +942,7 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout if parser is None: - parser = JsonParser() + parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( self.reactor, @@ -924,7 +964,7 @@ class MatrixFederationHttpClient: timeout: Optional[int] = None, ignore_backoff: bool = False, args: Optional[QueryParams] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: """Sends the specified json data using POST Args: @@ -998,7 +1038,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: ... @overload @@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient: timeout: Optional[int] = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser] = None, - ): + parser: Optional[ByteParser[T]] = None, + ) -> Union[JsonDict, T]: """GETs some json from the given host homeserver and path Args: @@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout if parser is None: - parser = JsonParser() + parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( self.reactor, @@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient: timeout: Optional[int] = None, ignore_backoff: bool = False, args: Optional[QueryParams] = None, - ) -> Union[JsonDict, list]: + ) -> JsonDict: """Send a DELETE request to the remote expecting some json response Args: diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 33af8770f..129d7cfd9 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) @@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index fdd22a8e9..d89a91c59 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel from synapse.api.errors import RequestSendFailed from synapse.http.matrixfederationclient import ( - JsonParser, + ByteParser, MatrixFederationHttpClient, MatrixFederationRequest, ) @@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase): while not test_d.called: protocol.dataReceived(b"a" * chunk_size) sent += chunk_size - self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE) + self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE) - self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE) + self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE) f = self.failureResultOf(test_d) self.assertIsInstance(f.value, RequestSendFailed)