0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-05-29 00:43:45 +02:00

Fix remaining mypy issues due to Twisted upgrade. (#9608)

This commit is contained in:
Patrick Cloke 2021-03-15 11:14:39 -04:00 committed by GitHub
parent 026503fa3b
commit d29b71aa50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 42 additions and 34 deletions

1
changelog.d/9608.misc Normal file
View file

@ -0,0 +1 @@
Fix incorrect type hints.

View file

@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol from twisted.internet import protocol
class RedisProtocol: class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ... def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ... async def ping(self) -> None: ...
async def set( async def set(

View file

@ -45,7 +45,9 @@ from twisted.internet.interfaces import (
IHostResolution, IHostResolution,
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IResolutionReceiver, IResolutionReceiver,
ITCPTransport,
) )
from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator from twisted.internet.task import Cooperator
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data.""" """A protocol which immediately errors upon receiving data."""
transport = None # type: Optional[ITCPTransport]
def __init__(self, deferred: defer.Deferred): def __init__(self, deferred: defer.Deferred):
self.deferred = deferred self.deferred = deferred
@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize()) self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get # Close the connection (forcefully) since all the data will get
# discarded anyway. # discarded anyway.
assert self.transport is not None
self.transport.abortConnection() self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None: def dataReceived(self, data: bytes) -> None:
self._maybe_fail() self._maybe_fail()
def connectionLost(self, reason: Failure) -> None: def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail() self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
transport = None # type: Optional[ITCPTransport]
def __init__( def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
): ):
@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize()) self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get # Close the connection (forcefully) since all the data will get
# discarded anyway. # discarded anyway.
assert self.transport is not None
self.transport.abortConnection() self.transport.abortConnection()
def connectionLost(self, reason: Failure) -> None: def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do. # If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called: if self.deferred.called:
return return

View file

@ -302,7 +302,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection hs, outbound_redis_connection
) )
hs.get_reactor().connectTCP( hs.get_reactor().connectTCP(
hs.config.redis.redis_host, hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port, hs.config.redis.redis_port,
self._factory, self._factory,
) )
@ -311,7 +311,7 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host host = hs.config.worker_replication_host
port = hs.config.worker_replication_port port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory) hs.get_reactor().connectTCP(host.encode(), port, self._factory)
def get_streams(self) -> Dict[str, Stream]: def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams.""" """Get a map from stream name to all streams."""

View file

@ -56,6 +56,7 @@ from prometheus_client import Counter
from zope.interface import Interface, implementer from zope.interface import Interface, implementer
from twisted.internet import task from twisted.internet import task
from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command) (if they send a `PING` command)
""" """
# The transport is going to be an ITCPTransport, but that doesn't have the
# (un)registerProducer methods, those are only on the implementation.
transport = None # type: Connection
delimiter = b"\n" delimiter = b"\n"
# Valid commands we expect to receive # Valid commands we expect to receive
@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
connected_connections.append(self) # Register connection for metrics connected_connections.append(self) # Register connection for metrics
assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks self.transport.registerProducer(self, True) # For the *Producing callbacks
self._send_pending_commands() self._send_pending_commands()
@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info( logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id() "[%s] Failed to close connection gracefully, aborting", self.id()
) )
assert self.transport is not None
self.transport.abortConnection() self.transport.abortConnection()
else: else:
if now - self.last_sent_command >= PING_TIME: if now - self.last_sent_command >= PING_TIME:
@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def close(self): def close(self):
logger.warning("[%s] Closing connection", self.id()) logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec() self.time_we_closed = self.clock.time_msec()
assert self.transport is not None
self.transport.loseConnection() self.transport.loseConnection()
self.on_connection_closed() self.on_connection_closed()
@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason): def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason) logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure): if isinstance(reason, Failure):
assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc() connection_close_counter.labels(reason.type.__name__).inc()
else: else:
connection_close_counter.labels(reason.__class__.__name__).inc() connection_close_counter.labels(reason.__class__.__name__).inc()

View file

@ -365,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect factory.continueTrying = reconnect
reactor = hs.get_reactor() reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None) reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
return factory.handler return factory.handler

View file

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import attr
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol from twisted.internet.protocol import Protocol
@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol # Set up client side protocol
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol # Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self.site) channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection() server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection() client_to_server_transport.loseConnection()
return request_factory.request return channel.request
def assert_request_is_get_repl_stream_updates( def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str self, request: SynapseRequest, stream_name: str
@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled: if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server. # Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback( self.reactor.add_tcp_client_callback(
"localhost", b"localhost",
6379, 6379,
self.connect_any_redis_attempts, self.connect_any_redis_attempts,
) )
@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol # Set up client side protocol
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol # Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs]) channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients clients = self.reactor.tcpClients
while clients: while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost") self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379) self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r)) self.received_rdata_rows.append((stream_name, token, r))
@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""
request = attr.ib(default=None)
def __call__(self, *args, **kwargs):
assert self.request is None
self.request = SynapseRequest(*args, **kwargs)
return self.request
class _PushHTTPChannel(HTTPChannel): class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers. """A HTTPChannel that wraps pull producers to push producers.
@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
""" """
def __init__( def __init__(
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site self, reactor: IReactorTime, request_factory: Type[Request], site: Site
): ):
super().__init__() super().__init__()
self.reactor = reactor self.reactor = reactor
@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"]) request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False return False
def requestDone(self, request):
# Store the request for inspection.
self.request = request
super().requestDone(request)
class _PullToPushProducer: class _PullToPushProducer:
"""A push producer that wraps a pull producer.""" """A push producer that wraps a pull producer."""
@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol): class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server.""" """A connection from a client talking to the fake Redis server."""
transport = None # type: Optional[FakeTransport]
def __init__(self, server: FakeRedisPubSubServer): def __init__(self, server: FakeRedisPubSubServer):
self._server = server self._server = server
self._reader = hiredis.Reader() self._reader = hiredis.Reader()
@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg): def send(self, msg):
"""Send a message back to the client.""" """Send a message back to the client."""
assert self.transport is not None
raw = self.encode(msg).encode("utf-8") raw = self.encode(msg).encode("utf-8")
self.transport.write(raw) self.transport.write(raw)

View file

@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IReactorTCP, IReactorTCP,
IResolverSimple, IResolverSimple,
ITransport,
) )
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock return clock, hs_clock
@implementer(ITransport)
@attr.s(cmp=False) @attr.s(cmp=False)
class FakeTransport: class FakeTransport:
""" """