0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-01 10:18:54 +02:00

Fix-up type hints in tests/server.py. (#15084)

This file was being ignored by mypy, we remove that
and add the missing type hints & deal with any fallout.
This commit is contained in:
Patrick Cloke 2023-02-17 13:19:38 -05:00 committed by GitHub
parent 61bfcd669a
commit c9b9143655
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 226 additions and 129 deletions

1
changelog.d/15084.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -31,8 +31,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py |synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/server.py
)$ )$
[mypy-synapse.federation.transport.client] [mypy-synapse.federation.transport.client]

View file

@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock from unittest.mock import Mock
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import ( from synapse.appservice import (
ApplicationService, ApplicationService,
@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock from ..utils import MockClock
if TYPE_CHECKING:
from twisted.internet.testing import MemoryReactor
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:

View file

@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator, IOpenSSLClientConnectionCreator,
IProtocolFactory, IProtocolFactory,
) )
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent from twisted.web.client import Agent
@ -466,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else: else:
assert isinstance(proxy_server_transport, FakeTransport) assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other client_protocol = proxy_server_transport.other
c2s_transport = client_protocol.transport assert isinstance(client_protocol, Protocol)
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol c2s_transport.other = server_ssl_protocol
self.reactor.advance(0) self.reactor.advance(0)

View file

@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
_WrappingProtocol, _WrappingProtocol,
) )
from twisted.internet.interfaces import IProtocol, IProtocolFactory from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else: else:
assert isinstance(proxy_server_transport, FakeTransport) assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other client_protocol = proxy_server_transport.other
c2s_transport = client_protocol.transport assert isinstance(client_protocol, Protocol)
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol c2s_transport.other = server_ssl_protocol
self.reactor.advance(0) self.reactor.advance(0)

View file

@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
from tests.server import FakeChannel, make_request from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless from tests.unittest import override_config, skip_unless
@ -1322,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token) channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
# Now try to exchange the login token # Now try to exchange the login token, it should fail.
channel = make_request( self.helper.login_via_token(login_token, 403)
self.hs.get_reactor(),
self.site,
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
# It should have failed
self.assertEqual(channel.code, 403)
@override_config( @override_config(
{ {

View file

@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr import attr
from typing_extensions import Literal from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site from twisted.web.server import Site
@ -67,6 +68,7 @@ class RestHelper:
""" """
hs: HomeServer hs: HomeServer
reactor: MemoryReactorClock
site: Site site: Site
auth_user_id: Optional[str] auth_user_id: Optional[str]
@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"POST", "POST",
path, path,
@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason data["reason"] = reason
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"POST", "POST",
path, path,
@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {}) data.update(extra_data or {})
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"PUT", "PUT",
path, path,
@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"PUT", "PUT",
path, path,
@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}" path = path + f"?access_token={tok}"
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
path, path,
@ -488,7 +490,7 @@ class RestHelper:
if body is not None: if body is not None:
content = json.dumps(body).encode("utf8") content = json.dumps(body).encode("utf8")
channel = make_request(self.hs.get_reactor(), self.site, method, path, content) channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data) image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
FakeSite(resource, self.hs.get_reactor()), FakeSite(resource, self.reactor),
"POST", "POST",
path, path,
content=image_data, content=image_data,
@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request expect_code: The return code to expect from attempting the whoami request
""" """
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
"account/whoami", "account/whoami",
@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]: ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC """Log in (as a new user) via OIDC
Returns the result of the final token login. Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body assert m, channel.text_body
login_token = m.group(1) login_token = m.group(1)
# finally, submit the matrix login token to the login API, which gives us our return self.login_via_token(login_token, expected_status), grant
# matrix access token and device id.
def login_via_token(
self,
login_token: str,
expected_status: int = 200,
) -> JsonDict:
"""Submit the matrix login token to the login API, which gives us our
matrix access token and device id.Log in (as a new user) via OIDC
Returns the result of the token login.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
"public_base_url".
Also requires the login servlet and the OIDC callback resource to be mounted at
the normal places.
"""
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"POST", "POST",
"/login", "/login",
@ -684,7 +704,7 @@ class RestHelper:
assert ( assert (
channel.code == expected_status channel.code == expected_status
), f"unexpected status in response: {channel.code}" ), f"unexpected status in response: {channel.code}"
return channel.json_body, grant return channel.json_body
def auth_via_oidc( def auth_via_oidc(
self, self,
@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs): with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code # now hit the callback URI with the right params and a made-up code
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
callback_uri, callback_uri,
@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to # is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy. # to keep Synapse happy.
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
uri, uri,
@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel) location = get_location(channel)
parts = urllib.parse.urlsplit(location) parts = urllib.parse.urlsplit(location)
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
urllib.parse.urlunsplit(("", "") + parts[2:]), urllib.parse.urlunsplit(("", "") + parts[2:]),
@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id}) + urllib.parse.urlencode({"session": ui_auth_session_id})
) )
# hit the redirect url (which will issue a cookie and state) # hit the redirect url (which will issue a cookie and state)
channel = make_request( channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
)
# that should serve a confirmation page # that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies) channel.extract_cookies(cookies)

View file

@ -22,20 +22,25 @@ import warnings
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import ( from typing import (
Any,
Awaitable,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
List, List,
MutableMapping, MutableMapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
TypeVar,
Union, Union,
cast,
) )
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque, ParamSpec
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import address, threads, udp from twisted.internet import address, threads, udp
@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IAddress, IAddress,
IConnector,
IConsumer, IConsumer,
IHostnameResolver, IHostnameResolver,
IProducer,
IProtocol, IProtocol,
IPullProducer, IPullProducer,
IPushProducer, IPushProducer,
@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
IResolverSimple, IResolverSimple,
ITransport, ITransport,
) )
from twisted.internet.protocol import ClientFactory, DatagramProtocol
from twisted.python import threadpool
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
@ -61,6 +70,7 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@ -88,6 +98,9 @@ from tests.utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
R = TypeVar("R")
P = ParamSpec("P")
# the type of thing that can be passed into `make_request` in the headers list # the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
@ -98,12 +111,14 @@ class TimedOutException(Exception):
""" """
@implementer(IConsumer) @implementer(ITransport, IPushProducer, IConsumer)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class FakeChannel: class FakeChannel:
""" """
A fake Twisted Web Channel (the part that interfaces with the A fake Twisted Web Channel (the part that interfaces with the
wire). wire).
See twisted.web.http.HTTPChannel.
""" """
site: Union[Site, "FakeSite"] site: Union[Site, "FakeSite"]
@ -142,7 +157,7 @@ class FakeChannel:
Raises an exception if the request has not yet completed. Raises an exception if the request has not yet completed.
""" """
if not self.is_finished: if not self.is_finished():
raise Exception("Request not yet completed") raise Exception("Request not yet completed")
return self.result["body"].decode("utf8") return self.result["body"].decode("utf8")
@ -165,27 +180,36 @@ class FakeChannel:
h.addRawHeader(*i) h.addRawHeader(*i)
return h return h
def writeHeaders(self, version, code, reason, headers): def writeHeaders(
self, version: bytes, code: bytes, reason: bytes, headers: Headers
) -> None:
self.result["version"] = version self.result["version"] = version
self.result["code"] = code self.result["code"] = code
self.result["reason"] = reason self.result["reason"] = reason
self.result["headers"] = headers self.result["headers"] = headers
def write(self, content: bytes) -> None: def write(self, data: bytes) -> None:
assert isinstance(content, bytes), "Should be bytes! " + repr(content) assert isinstance(data, bytes), "Should be bytes! " + repr(data)
if "body" not in self.result: if "body" not in self.result:
self.result["body"] = b"" self.result["body"] = b""
self.result["body"] += content self.result["body"] += data
def writeSequence(self, data: Iterable[bytes]) -> None:
for x in data:
self.write(x)
def loseConnection(self) -> None:
self.unregisterProducer()
self.transport.loseConnection()
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer. # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
def registerProducer( # type: ignore[override] def registerProducer(self, producer: IProducer, streaming: bool) -> None:
self, # TODO This should ensure that the IProducer is an IPushProducer or
producer: Union[IPullProducer, IPushProducer], # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
streaming: bool, # implement those, but doesn't declare it.
) -> None: self._producer = cast(Union[IPushProducer, IPullProducer], producer)
self._producer = producer
self.producerStreaming = streaming self.producerStreaming = streaming
def _produce() -> None: def _produce() -> None:
@ -202,6 +226,16 @@ class FakeChannel:
self._producer = None self._producer = None
def stopProducing(self) -> None:
if self._producer is not None:
self._producer.stopProducing()
def pauseProducing(self) -> None:
raise NotImplementedError()
def resumeProducing(self) -> None:
raise NotImplementedError()
def requestDone(self, _self: Request) -> None: def requestDone(self, _self: Request) -> None:
self.result["done"] = True self.result["done"] = True
if isinstance(_self, SynapseRequest): if isinstance(_self, SynapseRequest):
@ -281,12 +315,12 @@ class FakeSite:
self.reactor = reactor self.reactor = reactor
self.experimental_cors_msc3886 = experimental_cors_msc3886 self.experimental_cors_msc3886 = experimental_cors_msc3886
def getResourceFor(self, request): def getResourceFor(self, request: Request) -> IResource:
return self._resource return self._resource
def make_request( def make_request(
reactor, reactor: MemoryReactorClock,
site: Union[Site, FakeSite], site: Union[Site, FakeSite],
method: Union[bytes, str], method: Union[bytes, str],
path: Union[bytes, str], path: Union[bytes, str],
@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
A MemoryReactorClock that supports callFromThread. A MemoryReactorClock that supports callFromThread.
""" """
def __init__(self): def __init__(self) -> None:
self.threadpool = ThreadPool(self) self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
self._udp = [] self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {} self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque() self._thread_callbacks: Deque[Callable[..., R]] = deque()
lookups = self.lookups lookups = self.lookups
@implementer(IResolverSimple) @implementer(IResolverSimple)
class FakeResolver: class FakeResolver:
def getHostByName(self, name, timeout=None): def getHostByName(
self, name: str, timeout: Optional[Sequence[int]] = None
) -> "Deferred[str]":
if name not in lookups: if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name]) return succeed(lookups[name])
@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError() raise NotImplementedError()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): def listenUDP(
self,
port: int,
protocol: DatagramProtocol,
interface: str = "",
maxPacketSize: int = 8196,
) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self) p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening() p.startListening()
self._udp.append(p) self._udp.append(p)
return p return p
def callFromThread(self, callback, *args, **kwargs): def callFromThread(
self, callable: Callable[..., Any], *args: object, **kwargs: object
) -> None:
""" """
Make the callback fire in the next reactor iteration. Make the callback fire in the next reactor iteration.
""" """
cb = lambda: callback(*args, **kwargs) cb = lambda: callable(*args, **kwargs)
# it's not safe to call callLater() here, so we append the callback to a # it's not safe to call callLater() here, so we append the callback to a
# separate queue. # separate queue.
self._thread_callbacks.append(cb) self._thread_callbacks.append(cb)
def getThreadPool(self): def callInThread(
return self.threadpool self, callable: Callable[..., Any], *args: object, **kwargs: object
) -> None:
raise NotImplementedError()
def add_tcp_client_callback(self, host: str, port: int, callback: Callable): def suggestThreadPoolSize(self, size: int) -> None:
raise NotImplementedError()
def getThreadPool(self) -> "threadpool.ThreadPool":
# Cast to match super-class.
return cast(threadpool.ThreadPool, self.threadpool)
def add_tcp_client_callback(
self, host: str, port: int, callback: Callable[[], None]
) -> None:
"""Add a callback that will be invoked when we receive a connection """Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`. attempt to the given IP/port using `connectTCP`.
@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
""" """
self._tcp_callbacks[(host, port)] = callback self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): def connectTCP(
self,
host: str,
port: int,
factory: ClientFactory,
timeout: float = 30,
bindAddress: Optional[Tuple[str, int]] = None,
) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}.""" """Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP( conn = super().connectTCP(
@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn return conn
def advance(self, amount): def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that # first advance our reactor's time, and run any "callLater" callbacks that
# makes ready # makes ready
super().advance(amount) super().advance(amount)
@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool: class ThreadPool:
""" """
Threadless thread pool. Threadless thread pool.
See twisted.python.threadpool.ThreadPool
""" """
def __init__(self, reactor): def __init__(self, reactor: IReactorTime):
self._reactor = reactor self._reactor = reactor
def start(self): def start(self) -> None:
pass pass
def stop(self): def stop(self) -> None:
pass pass
def callInThreadWithCallback(self, onResult, function, *args, **kwargs): def callInThreadWithCallback(
def _(res): self,
onResult: Callable[[bool, Union[Failure, R]], None],
function: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> "Deferred[None]":
def _(res: Any) -> None:
if isinstance(res, Failure): if isinstance(res, Failure):
onResult(False, res) onResult(False, res)
else: else:
onResult(True, res) onResult(True, res)
d = Deferred() d: "Deferred[None]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs)) d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_) d.addBoth(_)
self._reactor.callLater(0, d.callback, True) self._reactor.callLater(0, d.callback, True)
@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases: for database in server.get_datastores().databases:
pool = database._db_pool pool = database._db_pool
def runWithConnection(func, *args, **kwargs): def runWithConnection(
func: Callable[..., R], *args: Any, **kwargs: Any
) -> Awaitable[R]:
return threads.deferToThreadPool( return threads.deferToThreadPool(
pool._reactor, pool._reactor,
pool.threadpool, pool.threadpool,
@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs, **kwargs,
) )
def runInteraction(interaction, *args, **kwargs): def runInteraction(
desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
) -> Awaitable[R]:
return threads.deferToThreadPool( return threads.deferToThreadPool(
pool._reactor, pool._reactor,
pool.threadpool, pool.threadpool,
pool._runInteraction, pool._runInteraction,
interaction, desc,
func,
*args, *args,
**kwargs, **kwargs,
) )
pool.runWithConnection = runWithConnection pool.runWithConnection = runWithConnection # type: ignore[assignment]
pool.runInteraction = runInteraction pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool # Replace the thread pool with a threadless 'thread' pool
pool.threadpool = ThreadPool(clock._reactor) pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
pool.running = True pool.running = True
# We've just changed the Databases to run DB transactions on the same # We've just changed the Databases to run DB transactions on the same
@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport) @implementer(ITransport)
@attr.s(cmp=False) @attr.s(cmp=False, auto_attribs=True)
class FakeTransport: class FakeTransport:
""" """
A twisted.internet.interfaces.ITransport implementation which sends all its data A twisted.internet.interfaces.ITransport implementation which sends all its data
@ -588,48 +663,50 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances. If you want bidirectional communication, you'll need two instances.
""" """
other = attr.ib() other: IProtocol
"""The Protocol object which will receive any data written to this transport. """The Protocol object which will receive any data written to this transport.
:type: twisted.internet.interfaces.IProtocol
""" """
_reactor = attr.ib() _reactor: IReactorTime
"""Test reactor """Test reactor
:type: twisted.internet.interfaces.IReactorTime
""" """
_protocol = attr.ib(default=None) _protocol: Optional[IProtocol] = None
"""The Protocol which is producing data for this transport. Optional, but if set """The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc. will get called back for connectionLost() notifications etc.
""" """
_peer_address: Optional[IAddress] = attr.ib(default=None) _peer_address: IAddress = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
)
"""The value to be returned by getPeer""" """The value to be returned by getPeer"""
_host_address: Optional[IAddress] = attr.ib(default=None) _host_address: IAddress = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
)
"""The value to be returned by getHost""" """The value to be returned by getHost"""
disconnecting = False disconnecting = False
disconnected = False disconnected = False
connected = True connected = True
buffer = attr.ib(default=b"") buffer: bytes = b""
producer = attr.ib(default=None) producer: Optional[IPushProducer] = None
autoflush = attr.ib(default=True) autoflush: bool = True
def getPeer(self) -> Optional[IAddress]: def getPeer(self) -> IAddress:
return self._peer_address return self._peer_address
def getHost(self) -> Optional[IAddress]: def getHost(self) -> IAddress:
return self._host_address return self._host_address
def loseConnection(self, reason=None): def loseConnection(self) -> None:
if not self.disconnecting: if not self.disconnecting:
logger.info("FakeTransport: loseConnection(%s)", reason) logger.info("FakeTransport: loseConnection()")
self.disconnecting = True self.disconnecting = True
if self._protocol: if self._protocol:
self._protocol.connectionLost(reason) self._protocol.connectionLost(
Failure(RuntimeError("FakeTransport.loseConnection()"))
)
# if we still have data to write, delay until that is done # if we still have data to write, delay until that is done
if self.buffer: if self.buffer:
@ -640,38 +717,38 @@ class FakeTransport:
self.connected = False self.connected = False
self.disconnected = True self.disconnected = True
def abortConnection(self): def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()") logger.info("FakeTransport: abortConnection()")
if not self.disconnecting: if not self.disconnecting:
self.disconnecting = True self.disconnecting = True
if self._protocol: if self._protocol:
self._protocol.connectionLost(None) self._protocol.connectionLost(None) # type: ignore[arg-type]
self.disconnected = True self.disconnected = True
def pauseProducing(self): def pauseProducing(self) -> None:
if not self.producer: if not self.producer:
return return
self.producer.pauseProducing() self.producer.pauseProducing()
def resumeProducing(self): def resumeProducing(self) -> None:
if not self.producer: if not self.producer:
return return
self.producer.resumeProducing() self.producer.resumeProducing()
def unregisterProducer(self): def unregisterProducer(self) -> None:
if not self.producer: if not self.producer:
return return
self.producer = None self.producer = None
def registerProducer(self, producer, streaming): def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer self.producer = producer
self.producerStreaming = streaming self.producerStreaming = streaming
def _produce(): def _produce() -> None:
if not self.producer: if not self.producer:
# we've been unregistered # we've been unregistered
return return
@ -683,7 +760,7 @@ class FakeTransport:
if not streaming: if not streaming:
self._reactor.callLater(0.0, _produce) self._reactor.callLater(0.0, _produce)
def write(self, byt): def write(self, byt: bytes) -> None:
if self.disconnecting: if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport") raise Exception("Writing to disconnecting FakeTransport")
@ -695,11 +772,11 @@ class FakeTransport:
if self.autoflush: if self.autoflush:
self._reactor.callLater(0.0, self.flush) self._reactor.callLater(0.0, self.flush)
def writeSequence(self, seq): def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq: for x in seq:
self.write(x) self.write(x)
def flush(self, maxbytes=None): def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer: if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the # nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol # TLSMemoryBIOProtocol
@ -750,17 +827,17 @@ def connect_client(
class TestHomeServer(HomeServer): class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore DATASTORE_CLASS = DataStore # type: ignore[assignment]
def setup_test_homeserver( def setup_test_homeserver(
cleanup_func, cleanup_func: Callable[[Callable[[], None]], None],
name="test", name: str = "test",
config=None, config: Optional[HomeServerConfig] = None,
reactor=None, reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer, homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs, **kwargs: Any,
): ) -> HomeServer:
""" """
Setup a homeserver suitable for running tests against. Keyword arguments Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. are passed to the Homeserver constructor.
@ -775,13 +852,14 @@ def setup_test_homeserver(
HomeserverTestCase. HomeserverTestCase.
""" """
if reactor is None: if reactor is None:
from twisted.internet import reactor from twisted.internet import reactor as _reactor
reactor = cast(ISynapseReactor, _reactor)
if config is None: if config is None:
config = default_config(name, parse=True) config = default_config(name, parse=True)
config.caches.resize_all_caches() config.caches.resize_all_caches()
config.ldap_enabled = False
if "clock" not in kwargs: if "clock" not in kwargs:
kwargs["clock"] = MockClock() kwargs["clock"] = MockClock()
@ -832,6 +910,8 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off # Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb() # the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine): if isinstance(db_engine, PostgresEngine):
import psycopg2.extensions
db_conn = db_engine.module.connect( db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB, database=POSTGRES_BASE_DB,
user=POSTGRES_USER, user=POSTGRES_USER,
@ -839,6 +919,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT, port=POSTGRES_PORT,
password=POSTGRES_PASSWORD, password=POSTGRES_PASSWORD,
) )
assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True db_conn.autocommit = True
cur = db_conn.cursor() cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@ -867,14 +948,15 @@ def setup_test_homeserver(
hs.setup_background_tasks() hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine): if isinstance(db_engine, PostgresEngine):
database = hs.get_datastores().databases[0] database_pool = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL # We need to do cleanup on PostgreSQL
def cleanup(): def cleanup() -> None:
import psycopg2 import psycopg2
import psycopg2.extensions
# Close all the db pools # Close all the db pools
database._db_pool.close() database_pool._db_pool.close()
dropped = False dropped = False
@ -886,6 +968,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT, port=POSTGRES_PORT,
password=POSTGRES_PASSWORD, password=POSTGRES_PASSWORD,
) )
assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True db_conn.autocommit = True
cur = db_conn.cursor() cur = db_conn.cursor()
@ -918,23 +1001,23 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it # Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one # because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg) # beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p): async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest() return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash hs.get_auth_handler().hash = hash # type: ignore[assignment]
async def validate_hash(p, h): async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
# Make the threadpool and database transactions synchronous for testing. # Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs) _make_test_homeserver_synchronous(hs)
# Load any configured modules into the homeserver # Load any configured modules into the homeserver
module_api = hs.get_module_api() module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules: for module, module_config in hs.config.modules.loaded_modules:
module(config=config, api=module_api) module(config=module_config, api=module_api)
load_legacy_spam_checkers(hs) load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs) load_legacy_third_party_event_rules(hs)

View file

@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from twisted.trial import unittest from twisted.trial import unittest
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request from twisted.web.server import Request
@ -82,7 +82,7 @@ from tests.server import (
) )
from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb from tests.utils import checked_cast, default_config, setupdb
setupdb() setupdb()
setup_logging() setup_logging()
@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.utils import RestHelper from tests.rest.client.utils import RestHelper
self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) self.helper = RestHelper(
self.hs,
checked_cast(MemoryReactorClock, self.hs.get_reactor()),
self.site,
getattr(self, "user_id", None),
)
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth: