From 3ab861ab9eaf54a336a5a900eeb8402c3e9ed811 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Oct 2023 14:28:05 -0400 Subject: [PATCH] Fix type hint errors from Twisted trunk (#16526) --- changelog.d/16526.misc | 1 + synapse/util/file_consumer.py | 16 +++++++++++----- tests/handlers/test_appservice.py | 1 + tests/http/server/_base.py | 2 +- tests/http/test_matrixfederationclient.py | 2 +- tests/unittest.py | 3 ++- 6 files changed, 17 insertions(+), 8 deletions(-) create mode 100644 changelog.d/16526.misc diff --git a/changelog.d/16526.misc b/changelog.d/16526.misc new file mode 100644 index 000000000..93ceaeafc --- /dev/null +++ b/changelog.d/16526.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 46771a401..26b46be5e 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,7 +13,7 @@ # limitations under the License. import queue -from typing import BinaryIO, Optional, Union, cast +from typing import Any, BinaryIO, Optional, Union, cast from twisted.internet import threads from twisted.internet.defer import Deferred @@ -58,7 +58,9 @@ class BackgroundFileConsumer: self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() # Deferred that is resolved when finished writing - self._finished_deferred: Optional[Deferred[None]] = None + # + # This is really Deferred[None], but mypy doesn't seem to like that. + self._finished_deferred: Optional[Deferred[Any]] = None # If the _writer thread throws an exception it gets stored here. self._write_exception: Optional[Exception] = None @@ -80,9 +82,13 @@ class BackgroundFileConsumer: self.streaming = streaming self._finished_deferred = run_in_background( threads.deferToThreadPool, - self._reactor, - self._reactor.getThreadPool(), - self._writer, + # mypy seems to get confused with the chaining of ParamSpec from + # run_in_background to deferToThreadPool. + # + # For Twisted trunk, ignore arg-type; for Twisted release ignore unused-ignore. + self._reactor, # type: ignore[arg-type,unused-ignore] + self._reactor.getThreadPool(), # type: ignore[arg-type,unused-ignore] + self._writer, # type: ignore[arg-type,unused-ignore] ) if not streaming: self._producer.resumeProducing() diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 867dbd600..c888d1ff0 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -156,6 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): result = self.successResultOf( defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias)) ) + assert result is not None self.mock_as_api.query_alias.assert_called_once_with( interested_service, room_alias_str diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 36472e57a..d524c183f 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -335,7 +335,7 @@ class Deferred__next__Patch: self._request_number = request_number self._seen_awaits = seen_awaits - self._original_Deferred___next__ = Deferred.__next__ + self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore] # The number of `await`s on `Deferred`s we have seen so far. self.awaits_seen = 0 diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index ab94f3f67..bf1d28769 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -70,7 +70,7 @@ class FederationClientTests(HomeserverTestCase): """ @defer.inlineCallbacks - def do_request() -> Generator["Deferred[object]", object, object]: + def do_request() -> Generator["Deferred[Any]", object, object]: with LoggingContext("one") as context: fetch_d = defer.ensureDeferred( self.cl.get_json("testserv:8008", "foo/bar") diff --git a/tests/unittest.py b/tests/unittest.py index 99ad02eb0..79c47fc3c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -30,6 +30,7 @@ from typing import ( Generic, Iterable, List, + Mapping, NoReturn, Optional, Tuple, @@ -251,7 +252,7 @@ class TestCase(unittest.TestCase): except AssertionError as e: raise (type(e))(f"Assert error for '.{key}':") from e - def assert_dict(self, required: dict, actual: dict) -> None: + def assert_dict(self, required: Mapping, actual: Mapping) -> None: """Does a partial assert of a dict. Args: