Make it possible to use dmypy (#9692)

Running `dmypy run` will do a `mypy` check while spinning up a daemon
that makes rerunning `dmypy run` a lot faster.

`dmypy` doesn't support `follow_imports = silent` and has
`local_partial_types` enabled, so this PR enables those options and
fixes the issues that were newly raised. Note that `local_partial_types`
will be enabled by default in upcoming mypy releases.
This commit is contained in:
Erik Johnston 2021-03-26 16:49:46 +00:00 committed by GitHub
parent 019010964d
commit b5efcb577e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 56 additions and 17 deletions

1
changelog.d/9692.misc Normal file
View file

@ -0,0 +1 @@
Make it possible to use `dmypy`.

View file

@ -1,12 +1,13 @@
[mypy] [mypy]
namespace_packages = True namespace_packages = True
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent follow_imports = normal
check_untyped_defs = True check_untyped_defs = True
show_error_codes = True show_error_codes = True
show_traceback = True show_traceback = True
mypy_path = stubs mypy_path = stubs
warn_unreachable = True warn_unreachable = True
local_partial_types = True
# To find all folders that pass mypy you run: # To find all folders that pass mypy you run:
# #

View file

@ -558,6 +558,9 @@ class Auth:
Returns: Returns:
bool: False if no access_token was given, True otherwise. bool: False if no access_token was given, True otherwise.
""" """
# This will always be set by the time Twisted calls us.
assert request.args is not None
query_params = request.args.get(b"access_token") query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@ -574,6 +577,8 @@ class Auth:
MissingClientTokenError: If there isn't a single access_token in the MissingClientTokenError: If there isn't a single access_token in the
request request
""" """
# This will always be set by the time Twisted calls us.
assert request.args is not None
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
query_params = request.args.get(b"access_token") query_params = request.args.get(b"access_token")

View file

@ -24,7 +24,7 @@ from ._base import Config, ConfigError
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache. # Map from canonicalised cache name to cache.
_CACHES = {} _CACHES = {} # type: Dict[str, Callable[[float], None]]
# a lock on the contents of _CACHES # a lock on the contents of _CACHES
_CACHES_LOCK = threading.Lock() _CACHES_LOCK = threading.Lock()
@ -59,7 +59,9 @@ def _canonicalise_cache_name(cache_name: str) -> str:
return cache_name.lower() return cache_name.lower()
def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): def add_resizable_cache(
cache_name: str, cache_resize_callback: Callable[[float], None]
):
"""Register a cache that's size can dynamically change """Register a cache that's size can dynamically change
Args: Args:

View file

@ -149,6 +149,9 @@ class OidcHandler:
Args: Args:
request: the incoming request from the browser. request: the incoming request from the browser.
""" """
# This will always be set by the time Twisted calls us.
assert request.args is not None
# The provider might redirect with an error. # The provider might redirect with an error.
# In that case, just display it as-is. # In that case, just display it as-is.
if b"error" in request.args: if b"error" in request.args:

View file

@ -262,7 +262,7 @@ logger = logging.getLogger(__name__)
# Block everything by default # Block everything by default
# A regex which matches the server_names to expose traces for. # A regex which matches the server_names to expose traces for.
# None means 'block everything'. # None means 'block everything'.
_homeserver_whitelist = None _homeserver_whitelist = None # type: Optional[re.Pattern[str]]
# Util methods # Util methods

View file

@ -104,7 +104,7 @@ tcp_outbound_commands_counter = Counter(
# A list of all connected protocols. This allows us to send metrics about the # A list of all connected protocols. This allows us to send metrics about the
# connections. # connections.
connected_connections = [] connected_connections = [] # type: List[BaseReplicationStreamProtocol]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -390,6 +390,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_identifier: str self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)

View file

@ -833,6 +833,9 @@ class UserMediaRestServlet(RestServlet):
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)): if not self.is_mine(UserID.from_string(user_id)):

View file

@ -91,6 +91,9 @@ class SyncRestServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None
if b"from" in request.args: if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'. # /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'. # Lets be helpful and whine if we see a 'from'.

View file

@ -187,6 +187,8 @@ class PreviewUrlResource(DirectServeJsonResource):
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: SynapseRequest) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)

View file

@ -104,6 +104,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
async def _async_render_POST(self, request: SynapseRequest): async def _async_render_POST(self, request: SynapseRequest):
# This will always be set by the time Twisted calls us.
assert request.args is not None
try: try:
session_id = get_username_mapping_session_cookie_from_request(request) session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e: except SynapseError as e:

View file

@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
caches_by_name = {} caches_by_name = {} # type: Dict[str, Sized]
collectors_by_name = {} # type: Dict collectors_by_name = {} # type: Dict[str, CacheMetric]
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

View file

@ -69,6 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing") self.assert_request_is_get_repl_stream_updates(request, "typing")
# The from token should be the token from the last RDATA we got. # The from token should be the token from the last RDATA we got.
assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token) self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once() self.test_handler.on_rdata.assert_called_once()

View file

@ -15,7 +15,7 @@
import logging import logging
import os import os
from binascii import unhexlify from binascii import unhexlify
from typing import Tuple from typing import Optional, Tuple
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.protocols.tls import TLSMemoryBIOFactory
@ -32,7 +32,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
test_server_connection_factory = None test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):

View file

@ -2,7 +2,7 @@ import json
import logging import logging
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque
@ -13,8 +13,11 @@ from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, succeed from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IHostnameResolver,
IProtocol,
IPullProducer,
IPushProducer,
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple, IResolverSimple,
ITransport, ITransport,
) )
@ -45,11 +48,11 @@ class FakeChannel:
wire). wire).
""" """
site = attr.ib(type=Site) site = attr.ib(type=Union[Site, "FakeSite"])
_reactor = attr.ib() _reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict)) result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1") _ip = attr.ib(type=str, default="127.0.0.1")
_producer = None _producer = None # type: Optional[Union[IPullProducer, IPushProducer]]
@property @property
def json_body(self): def json_body(self):
@ -159,7 +162,11 @@ class FakeChannel:
Any cookines found are added to the given dict Any cookines found are added to the given dict
""" """
for h in self.headers.getRawHeaders("Set-Cookie"): headers = self.headers.getRawHeaders("Set-Cookie")
if not headers:
return
for h in headers:
parts = h.split(";") parts = h.split(";")
k, v = parts[0].split("=", maxsplit=1) k, v = parts[0].split("=", maxsplit=1)
cookies[k] = v cookies[k] = v
@ -311,8 +318,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks = {} self._tcp_callbacks = {}
self._udp = [] self._udp = []
lookups = self.lookups = {} lookups = self.lookups = {} # type: Dict[str, str]
self._thread_callbacks = deque() # type: Deque[Callable[[], None]]() self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
@implementer(IResolverSimple) @implementer(IResolverSimple)
class FakeResolver: class FakeResolver:
@ -324,6 +331,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self.nameResolver = SimpleResolverComplexifier(FakeResolver()) self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__() super().__init__()
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self) p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening() p.startListening()
@ -621,7 +631,9 @@ class FakeTransport:
self.disconnected = True self.disconnected = True
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: def connect_client(
reactor: ThreadedMemoryReactorClock, client_id: int
) -> Tuple[IProtocol, AccumulatingProtocol]:
""" """
Connect a client to a fake TCP transport. Connect a client to a fake TCP transport.