Fix remaining mypy issues due to Twisted upgrade. (#9608)

This commit is contained in:
Patrick Cloke 2021-03-15 11:14:39 -04:00 committed by GitHub
parent 026503fa3b
commit d29b71aa50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 42 additions and 34 deletions

1
changelog.d/9608.misc Normal file
View File

@ -0,0 +1 @@
Fix incorrect type hints.

View File

@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol
class RedisProtocol:
class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(

View File

@ -45,7 +45,9 @@ from twisted.internet.interfaces import (
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
ITCPTransport,
)
from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
transport = None # type: Optional[ITCPTransport]
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
transport = None # type: Optional[ITCPTransport]
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()
def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return

View File

@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,
hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""

View File

@ -56,6 +56,7 @@ from prometheus_client import Counter
from zope.interface import Interface, implementer
from twisted.internet import task
from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""
# The transport is going to be an ITCPTransport, but that doesn't have the
# (un)registerProducer methods, those are only on the implementation.
transport = None # type: Connection
delimiter = b"\n"
# Valid commands we expect to receive
@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics
assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands()
@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()

View File

@ -365,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler

View File

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
return request_factory.request
return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
b"localhost",
6379,
self.connect_any_redis_attempts,
)
@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""
request = attr.ib(default=None)
def __call__(self, *args, **kwargs):
assert self.request is None
self.request = SynapseRequest(*args, **kwargs)
return self.request
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
"""
def __init__(
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
def requestDone(self, request):
# Store the request for inspection.
self.request = request
super().requestDone(request)
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
transport = None # type: Optional[FakeTransport]
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg):
"""Send a message back to the client."""
assert self.transport is not None
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)

View File

@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock
@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""