mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-12 13:01:34 +01:00
db95b75515
When running unit tests, we patch the database connection pool so that it runs queries "synchronously". This is ok, except that if any queries are launched before we do the patching, those queries get left in limbo and never complete. To fix this, let's change the way we do the switcheroo, by patching out the method which creates the connection pool in the first place.
1159 lines
37 KiB
Python
1159 lines
37 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
import hashlib
|
|
import ipaddress
|
|
import json
|
|
import logging
|
|
import os
|
|
import os.path
|
|
import sqlite3
|
|
import time
|
|
import uuid
|
|
import warnings
|
|
from collections import deque
|
|
from io import SEEK_END, BytesIO
|
|
from typing import (
|
|
Any,
|
|
Awaitable,
|
|
Callable,
|
|
Deque,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
MutableMapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
)
|
|
from unittest.mock import Mock, patch
|
|
|
|
import attr
|
|
from incremental import Version
|
|
from typing_extensions import ParamSpec
|
|
from zope.interface import implementer
|
|
|
|
import twisted
|
|
from twisted.enterprise import adbapi
|
|
from twisted.internet import address, tcp, threads, udp
|
|
from twisted.internet._resolver import SimpleResolverComplexifier
|
|
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
|
|
from twisted.internet.error import DNSLookupError
|
|
from twisted.internet.interfaces import (
|
|
IAddress,
|
|
IConnector,
|
|
IConsumer,
|
|
IHostnameResolver,
|
|
IListeningPort,
|
|
IProducer,
|
|
IProtocol,
|
|
IPullProducer,
|
|
IPushProducer,
|
|
IReactorPluggableNameResolver,
|
|
IReactorTime,
|
|
IResolverSimple,
|
|
ITransport,
|
|
)
|
|
from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory
|
|
from twisted.python import threadpool
|
|
from twisted.python.failure import Failure
|
|
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
|
from twisted.web.http_headers import Headers
|
|
from twisted.web.resource import IResource
|
|
from twisted.web.server import Request, Site
|
|
|
|
from synapse.config.database import DatabaseConnectionConfig
|
|
from synapse.config.homeserver import HomeServerConfig
|
|
from synapse.events.presence_router import load_legacy_presence_router
|
|
from synapse.handlers.auth import load_legacy_password_auth_providers
|
|
from synapse.http.site import SynapseRequest
|
|
from synapse.logging.context import ContextResourceUsage
|
|
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
|
|
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
|
|
load_legacy_third_party_event_rules,
|
|
)
|
|
from synapse.server import HomeServer
|
|
from synapse.storage import DataStore
|
|
from synapse.storage.database import LoggingDatabaseConnection, make_pool
|
|
from synapse.storage.engines import BaseDatabaseEngine, create_engine
|
|
from synapse.storage.prepare_database import prepare_database
|
|
from synapse.types import ISynapseReactor, JsonDict
|
|
from synapse.util import Clock
|
|
|
|
from tests.utils import (
|
|
LEAVE_DB,
|
|
POSTGRES_BASE_DB,
|
|
POSTGRES_HOST,
|
|
POSTGRES_PASSWORD,
|
|
POSTGRES_PORT,
|
|
POSTGRES_USER,
|
|
SQLITE_PERSIST_DB,
|
|
USE_POSTGRES_FOR_TESTS,
|
|
MockClock,
|
|
default_config,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Type variables for the generic helpers in this module (the threadless
# ThreadPool and the db-pool wrappers in make_fake_db_pool).
R = TypeVar("R")
P = ParamSpec("P")

# The type of thing that can be passed into `make_request` in the
# `custom_headers` list: a (name, value) pair, each either str or bytes.
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]

# A pre-prepared SQLite DB that is used as a template when creating a new SQLite
# DB for each test run. This dramatically speeds up test set-up when using SQLite.
# It is created lazily, the first time `setup_test_homeserver` needs it.
PREPPED_SQLITE_DB_CONN: Optional[LoggingDatabaseConnection] = None
|
|
|
|
|
|
class TimedOutException(Exception):
    """
    Raised when a fake web request does not finish within the allowed time
    (see `FakeChannel.await_result`).
    """
|
|
|
|
|
|
@implementer(ITransport, IPushProducer, IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
    """
    A fake Twisted Web Channel (the part that interfaces with the
    wire).

    Everything the request writes back is recorded in `result`:
    `writeHeaders` fills in "version", "code", "reason" and "headers",
    `write` appends to "body", and `requestDone` sets "done".

    See twisted.web.http.HTTPChannel.
    """

    site: Union[Site, "FakeSite"]
    _reactor: MemoryReactorClock
    result: dict = attr.Factory(dict)
    _ip: str = "127.0.0.1"
    _producer: Optional[Union[IPullProducer, IPushProducer]] = None
    resource_usage: Optional[ContextResourceUsage] = None
    _request: Optional[Request] = None

    @property
    def request(self) -> Request:
        """The request being served on this channel. Must already have been set."""
        assert self._request is not None
        return self._request

    @request.setter
    def request(self, request: Request) -> None:
        # The request may only be attached to the channel once.
        assert self._request is None
        self._request = request

    @property
    def json_body(self) -> JsonDict:
        """The response body, parsed as a JSON object."""
        body = json.loads(self.text_body)
        assert isinstance(body, dict)
        return body

    @property
    def json_list(self) -> List[JsonDict]:
        """The response body, parsed as a JSON array."""
        body = json.loads(self.text_body)
        assert isinstance(body, list)
        return body

    @property
    def text_body(self) -> str:
        """The body of the result, utf-8-decoded.

        Raises an exception if the request has not yet completed.
        """
        if not self.is_finished():
            raise Exception("Request not yet completed")
        return self.result["body"].decode("utf8")

    def is_finished(self) -> bool:
        """check if the response has been completely received"""
        return self.result.get("done", False)

    @property
    def code(self) -> int:
        """The response status code, as an int.

        Raises an exception if nothing has been recorded in `result` yet.
        """
        if not self.result:
            raise Exception("No result yet.")
        return int(self.result["code"])

    @property
    def headers(self) -> Headers:
        """The response headers, rebuilt as a twisted `Headers` object.

        Raises an exception if nothing has been recorded in `result` yet.
        """
        if not self.result:
            raise Exception("No result yet.")
        h = Headers()
        # The recorded headers are iterated as (name, value) pairs.
        for i in self.result["headers"]:
            h.addRawHeader(*i)
        return h

    def writeHeaders(
        self, version: bytes, code: bytes, reason: bytes, headers: Headers
    ) -> None:
        """Record the status line and headers of the response."""
        self.result["version"] = version
        self.result["code"] = code
        self.result["reason"] = reason
        self.result["headers"] = headers

    def write(self, data: bytes) -> None:
        """Append `data` to the recorded response body."""
        assert isinstance(data, bytes), "Should be bytes! " + repr(data)

        if "body" not in self.result:
            self.result["body"] = b""

        self.result["body"] += data

    def writeSequence(self, data: Iterable[bytes]) -> None:
        """Append each chunk of `data` to the recorded response body, in order."""
        for x in data:
            self.write(x)

    def loseConnection(self) -> None:
        """Stop pulling data from any registered producer.

        This fake channel is its own transport (see `transport`), so there is
        no underlying connection to close. Delegating to
        `self.transport.loseConnection()` would simply call this method again
        and recurse until a RecursionError was raised.
        """
        self.unregisterProducer()

    def registerProducer(self, producer: IProducer, streaming: bool) -> None:
        """Register a producer for the response body.

        For a non-streaming (pull) producer, `resumeProducing` is called on the
        next reactor iteration and then every 0.1s until it is unregistered.
        """
        # TODO This should ensure that the IProducer is an IPushProducer or
        # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
        # implement those, but doesn't declare it.
        # The interface only promises an IProducer, so cast it to the union we store.
        self._producer = cast(Union[IPushProducer, IPullProducer], producer)
        self.producerStreaming = streaming

        def _produce() -> None:
            if self._producer:
                self._producer.resumeProducing()
                self._reactor.callLater(0.1, _produce)

        if not streaming:
            self._reactor.callLater(0.0, _produce)

    def unregisterProducer(self) -> None:
        """Forget the registered producer, if any (stops the pull loop)."""
        if self._producer is None:
            return

        self._producer = None

    def stopProducing(self) -> None:
        """Forward `stopProducing` to the registered producer, if any."""
        if self._producer is not None:
            self._producer.stopProducing()

    def pauseProducing(self) -> None:
        raise NotImplementedError()

    def resumeProducing(self) -> None:
        raise NotImplementedError()

    def requestDone(self, _self: Request) -> None:
        """Mark the response as complete, capturing resource usage if available."""
        self.result["done"] = True
        if isinstance(_self, SynapseRequest):
            assert _self.logcontext is not None
            self.resource_usage = _self.logcontext.get_resource_usage()

    def getPeer(self) -> IAddress:
        # We give an address so that getClientAddress/getClientIP returns a non null entry,
        # causing us to record the MAU
        return address.IPv4Address("TCP", self._ip, 3423)

    def getHost(self) -> IAddress:
        # this is called by Request.__init__ to configure Request.host.
        return address.IPv4Address("TCP", "127.0.0.1", 8888)

    def isSecure(self) -> bool:
        return False

    @property
    def transport(self) -> "FakeChannel":
        """The channel acts as its own transport."""
        return self

    def await_result(self, timeout_ms: int = 1000) -> None:
        """
        Wait until the request is finished.

        Advances the reactor in 0.1s steps, nudging any registered producer,
        until the response is complete.

        Args:
            timeout_ms: how much (reactor) time to allow before giving up.

        Raises:
            TimedOutException: if the request is still unfinished after
                `timeout_ms` of reactor time.
        """
        end_time = self._reactor.seconds() + timeout_ms / 1000.0
        self._reactor.run()

        while not self.is_finished():
            # If there's a producer, tell it to resume producing so we get content
            if self._producer:
                self._producer.resumeProducing()

            if self._reactor.seconds() > end_time:
                raise TimedOutException("Timed out waiting for request to finish.")

            self._reactor.advance(0.1)

    def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
        """Process the contents of any Set-Cookie headers in the response

        Any cookies found are added to the given dict
        """
        headers = self.headers.getRawHeaders("Set-Cookie")
        if not headers:
            return

        for h in headers:
            # Only the leading "name=value" part matters; attributes such as
            # "Path=..." after the first ";" are ignored.
            parts = h.split(";")
            k, v = parts[0].split("=", maxsplit=1)
            cookies[k] = v
|
|
|
|
|
|
class FakeSite:
    """
    Stand-in for a Twisted Web `Site`, carrying the extra attributes Synapse
    expects a site to provide (version string, site tag, access logger).
    """

    server_version_string = b"1"
    site_tag = "test"
    access_logger = logging.getLogger("synapse.access.http.fake")

    def __init__(
        self,
        resource: IResource,
        reactor: IReactorTime,
        experimental_cors_msc3886: bool = False,
    ):
        """
        Args:
            resource: the resource handed back by `getResourceFor` for every request
            reactor: stored as `self.reactor`
            experimental_cors_msc3886: stored as `self.experimental_cors_msc3886`
        """
        self.experimental_cors_msc3886 = experimental_cors_msc3886
        self.reactor = reactor
        self._resource = resource

    def getResourceFor(self, request: Request) -> IResource:
        """Return the one configured resource, regardless of `request`."""
        return self._resource
|
|
|
|
|
|
def make_request(
    reactor: MemoryReactorClock,
    site: Union[Site, FakeSite],
    method: Union[bytes, str],
    path: Union[bytes, str],
    content: Union[bytes, str, JsonDict] = b"",
    access_token: Optional[str] = None,
    request: Type[Request] = SynapseRequest,
    shorthand: bool = True,
    federation_auth_origin: Optional[bytes] = None,
    content_is_form: bool = False,
    await_result: bool = True,
    custom_headers: Optional[Iterable[CustomHeaderType]] = None,
    client_ip: str = "127.0.0.1",
) -> FakeChannel:
    """
    Build a web request from the given method, path and content, and render it
    against `site`.

    Args:
        reactor: the test reactor the fake channel schedules work on and is
            pumped while waiting for the result.
        site: The twisted Site to use to render the request
        method: The HTTP request method ("verb").
        path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
        content: The body of the request. A dict is JSON-encoded, a str is
            UTF-8 encoded, and bytes are sent unchanged.
        access_token: The access token to add as authorization for the request.
        request: The request class to create.
        shorthand: Whether to try and be helpful and prefix the given URL
            with the usual REST API path, if it doesn't contain it.
        federation_auth_origin: if set to not-None, we will add a fake
            Authorization header pretending to be the given server name.
        content_is_form: Whether the content is URL encoded form data. Adds the
            'Content-Type': 'application/x-www-form-urlencoded' header.
        await_result: whether to wait for the request to complete rendering. If true,
            will pump the reactor until the renderer tells the channel the request
            is finished.
        custom_headers: (name, value) pairs to add as request headers
        client_ip: The IP to use as the requesting IP. Useful for testing
            ratelimiting.

    Returns:
        The fake channel, which records the response to the request.
    """
    raw_method = method if isinstance(method, bytes) else method.encode("ascii")
    raw_path = path if isinstance(path, bytes) else path.encode("ascii")

    # With `shorthand`, anything not already under /_matrix or /_synapse is
    # treated as relative to the client-server API root.
    if shorthand and not raw_path.startswith((b"/_matrix", b"/_synapse")):
        relative = raw_path[1:] if raw_path.startswith(b"/") else raw_path
        raw_path = b"/_matrix/client/r0/" + relative

    if not raw_path.startswith(b"/"):
        raw_path = b"/" + raw_path

    # Normalise the body to bytes.
    if isinstance(content, dict):
        body = json.dumps(content).encode("utf8")
    elif isinstance(content, str):
        body = content.encode("utf8")
    else:
        body = content

    channel = FakeChannel(site, reactor, ip=client_ip)
    req = request(channel, site)
    channel.request = req

    req.content = BytesIO(body)
    # Twisted expects to be at the end of the content when parsing the request.
    req.content.seek(0, SEEK_END)

    # Old versions of Twisted (<20.3.0) have issues with parsing
    # x-www-form-urlencoded bodies if the Content-Length header is missing,
    # so it is always set.
    extra_headers: List[CustomHeaderType] = [
        (b"Content-Length", str(len(body)).encode("ascii"))
    ]
    if access_token:
        extra_headers.append(
            (b"Authorization", b"Bearer " + access_token.encode("ascii"))
        )
    if federation_auth_origin is not None:
        extra_headers.append(
            (
                b"Authorization",
                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
            )
        )
    if body:
        # Anything not declared to be form data is assumed to be JSON.
        content_type = (
            b"application/x-www-form-urlencoded"
            if content_is_form
            else b"application/json"
        )
        extra_headers.append((b"Content-Type", content_type))
    if custom_headers:
        extra_headers.extend(custom_headers)

    for header_name, header_value in extra_headers:
        req.requestHeaders.addRawHeader(header_name, header_value)

    req.parseCookies()
    req.requestReceived(raw_method, raw_path, b"1.1")

    if await_result:
        channel.await_result()

    return channel
|
|
|
|
|
|
# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly
# marking this as an implementer of the latter seems to keep mypy-zope happier.
@implementer(IReactorPluggableNameResolver, ISynapseReactor)
class ThreadedMemoryReactorClock(MemoryReactorClock):
    """
    A MemoryReactorClock that supports callFromThread.

    It also provides:
      * a fake name resolver answering from the `lookups` dict,
      * a threadless `ThreadPool` (returned by `getThreadPool`),
      * per-(host, port) callbacks fired from `connectTCP`.
    """

    def __init__(self) -> None:
        self.threadpool = ThreadPool(self)

        self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
        self._udp: List[udp.Port] = []
        self.lookups: Dict[str, str] = {}
        # Callbacks queued by callFromThread, drained in advance().
        self._thread_callbacks: Deque[Callable[[], Any]] = deque()

        known_hosts = self.lookups

        @implementer(IResolverSimple)
        class FakeResolver:
            def getHostByName(
                self, name: str, timeout: Optional[Sequence[int]] = None
            ) -> "Deferred[str]":
                if name in known_hosts:
                    return succeed(known_hosts[name])
                return fail(DNSLookupError("OH NO: unknown %s" % (name,)))

        # In order for the TLS protocol tests to work, modify _get_default_clock
        # on newer Twisted versions to use the test reactor's clock.
        #
        # This is *super* dirty since it is never undone and relies on the next
        # test to overwrite it.
        if twisted.version > Version("Twisted", 23, 8, 0):
            from twisted.protocols import tls

            tls._get_default_clock = lambda: self

        self.nameResolver = SimpleResolverComplexifier(FakeResolver())
        super().__init__()

    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
        raise NotImplementedError()

    def listenUDP(
        self,
        port: int,
        protocol: DatagramProtocol,
        interface: str = "",
        maxPacketSize: int = 8196,
    ) -> udp.Port:
        """Open a UDP port bound to this reactor, remembering it in `self._udp`."""
        udp_port = udp.Port(port, protocol, interface, maxPacketSize, self)
        udp_port.startListening()
        self._udp.append(udp_port)
        return udp_port

    def callFromThread(
        self, callable: Callable[..., Any], *args: object, **kwargs: object
    ) -> None:
        """
        Make the callback fire in the next reactor iteration.
        """

        def _invoke() -> None:
            callable(*args, **kwargs)

        # It's not safe to call callLater() here, so the callback goes on a
        # separate queue which advance() drains.
        self._thread_callbacks.append(_invoke)

    def callInThread(
        self, callable: Callable[..., Any], *args: object, **kwargs: object
    ) -> None:
        raise NotImplementedError()

    def suggestThreadPoolSize(self, size: int) -> None:
        raise NotImplementedError()

    def getThreadPool(self) -> "threadpool.ThreadPool":
        # Cast to match super-class.
        return cast(threadpool.ThreadPool, self.threadpool)

    def add_tcp_client_callback(
        self, host: str, port: int, callback: Callable[[], None]
    ) -> None:
        """Add a callback that will be invoked when we receive a connection
        attempt to the given IP/port using `connectTCP`.

        Note that the callback gets run before we return the connection to the
        client, which means callbacks cannot block while waiting for writes.
        """
        self._tcp_callbacks[(host, port)] = callback

    def connectUNIX(
        self,
        address: str,
        factory: ClientFactory,
        timeout: float = 30,
        checkPID: int = 0,
    ) -> IConnector:
        """
        Unix sockets aren't supported for unit tests yet. Make it obvious to any
        developer trying it out that they will need to do some work before being able
        to use it in tests.
        """
        raise Exception("Unix sockets are not implemented for tests yet, sorry.")

    def listenUNIX(
        self,
        address: str,
        factory: Factory,
        backlog: int = 50,
        mode: int = 0o666,
        wantPID: int = 0,
    ) -> IListeningPort:
        """
        Unix sockets aren't supported for unit tests yet. Make it obvious to any
        developer trying it out that they will need to do some work before being able
        to use it in tests.
        """
        raise Exception("Unix sockets are not implemented for tests, sorry")

    def connectTCP(
        self,
        host: str,
        port: int,
        factory: ClientFactory,
        timeout: float = 30,
        bindAddress: Optional[Tuple[str, int]] = None,
    ) -> IConnector:
        """Fake L{IReactorTCP.connectTCP}."""

        connector = super().connectTCP(
            host, port, factory, timeout=timeout, bindAddress=None
        )

        # If the test registered an address for this host, check that the
        # connector's destination is consistent with it.
        if host in self.lookups:
            validate_connector(connector, self.lookups[host])

        on_connect = self._tcp_callbacks.get((host, port))
        if on_connect:
            on_connect()

        return connector

    def advance(self, amount: float) -> None:
        # first advance our reactor's time, and run any "callLater" callbacks that
        # makes ready
        super().advance(amount)

        # now run any "callFromThread" callbacks (including any queued by the
        # callbacks themselves)
        pending = self._thread_callbacks
        while pending:
            pending.popleft()()

        # check for more "callLater" callbacks added by the thread callback
        # This isn't required in a regular reactor, but it ends up meaning that
        # our database queries can complete in a single call to `advance` [1] which
        # simplifies tests.
        #
        # [1]: we replace the threadpool backing the db connection pool with a
        # mock ThreadPool which doesn't really use threads; but we still use
        # reactor.callFromThread to feed results back from the db functions to the
        # main thread.
        super().advance(0)
|
|
|
|
|
|
def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
    """Try to validate the obtained connector as it would happen when
    synapse is running and the connection will be established.

    This method will raise a useful exception when necessary, else it will
    just do nothing.

    This is in order to help catch quirks related to reactor.connectTCP,
    since when called directly, the connector's destination will be of type
    IPv4Address, with the hostname as the literal host that was given (which
    could be an IPv6-only host or an IPv6 literal).

    But when called from reactor.connectTCP *through* e.g. an Endpoint, the
    connector's destination will contain the specific IP address with the
    correct network stack class.

    Note that testing code paths that use connectTCP directly should not be
    affected by this check, unless they specifically add a test with a
    matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of
    type ThreadedMemoryReactorClock.
    For an example of implementing such tests, see tests/handlers/test_send_email.py.

    Args:
        connector: the connector returned by `connectTCP`.
        expected_ip: the IP address the destination host is expected to resolve to.

    Raises:
        ValueError: if the destination is neither an IPv4 nor an IPv6 address,
            or if `expected_ip` is not a valid address of the destination's
            address family.
    """
    destination = connector.getDestination()

    # We use address.IPv{4,6}Address to check what the reactor thinks it is
    # sending, but check for validity with ipaddress.IPv{4,6}Address
    # because they fail with IPs on the wrong network stack.
    cls_mapping = {
        address.IPv4Address: ipaddress.IPv4Address,
        address.IPv6Address: ipaddress.IPv6Address,
    }

    cls = cls_mapping.get(destination.__class__)
    if cls is None:
        raise ValueError(
            "Unknown address type %s for %s"
            % (destination.__class__.__name__, destination)
        )

    try:
        cls(expected_ip)
    except ValueError as exc:
        # ipaddress raises AddressValueError (a ValueError subclass) for
        # addresses that are malformed or belong to the other network stack.
        raise ValueError(
            "Invalid IP type and resolution for %s. Expected %s to be %s"
            % (destination, expected_ip, cls.__name__)
        ) from exc
|
|
|
|
|
|
def make_fake_db_pool(
    reactor: ISynapseReactor,
    db_config: DatabaseConnectionConfig,
    engine: BaseDatabaseEngine,
) -> adbapi.ConnectionPool:
    """Build a database connection pool whose queries run synchronously.

    This is a drop-in replacement for the normal `make_pool`. For more
    deterministic testing it does not hand work to real threads: instead every
    db query runs on the test reactor's main thread, via a threadless
    `ThreadPool`.
    """
    pool = make_pool(reactor, db_config, engine)

    def _defer_to_pool(
        target: Callable[..., R], /, *args: Any, **kwargs: Any
    ) -> Awaitable[R]:
        # `pool.threadpool` is read on every call (not captured now), so the
        # threadless pool swapped in below is the one that runs `target`.
        return threads.deferToThreadPool(
            pool._reactor, pool.threadpool, target, *args, **kwargs
        )

    def runWithConnection(
        func: Callable[..., R], *args: Any, **kwargs: Any
    ) -> Awaitable[R]:
        return _defer_to_pool(pool._runWithConnection, func, *args, **kwargs)

    def runInteraction(
        desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
    ) -> Awaitable[R]:
        return _defer_to_pool(pool._runInteraction, desc, func, *args, **kwargs)

    pool.runWithConnection = runWithConnection  # type: ignore[method-assign]
    pool.runInteraction = runInteraction  # type: ignore[assignment]
    # Swap in the threadless 'thread' pool and flag the pool as running.
    pool.threadpool = ThreadPool(reactor)
    pool.running = True
    return pool
|
|
|
|
|
|
class ThreadPool:
    """
    A threadless stand-in for a thread pool: work is scheduled on the reactor
    instead of being run in worker threads.

    See twisted.python.threadpool.ThreadPool
    """

    def __init__(self, reactor: IReactorTime):
        self._reactor = reactor

    def start(self) -> None:
        """Nothing to do: there are no worker threads to start."""

    def stop(self) -> None:
        """Nothing to do: there are no worker threads to stop."""

    def callInThreadWithCallback(
        self,
        onResult: Callable[[bool, Union[Failure, R]], None],
        function: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> "Deferred[None]":
        """Run `function(*args, **kwargs)` on the next reactor iteration.

        The outcome is reported to `onResult` as `(True, result)` on success,
        or `(False, failure)` if the call raised.
        """

        def _report(outcome: Any) -> None:
            onResult(not isinstance(outcome, Failure), outcome)

        deferred: "Deferred[None]" = Deferred()
        deferred.addCallback(lambda _ignored: function(*args, **kwargs))
        deferred.addBoth(_report)
        self._reactor.callLater(0, deferred.callback, True)
        return deferred
|
|
|
|
|
|
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
    """Return a fresh test reactor and a Synapse `Clock` wrapping it."""
    test_reactor = ThreadedMemoryReactorClock()
    return test_reactor, Clock(test_reactor)
|
|
|
|
|
|
@implementer(ITransport)
@attr.s(eq=False, auto_attribs=True)
class FakeTransport:
    """
    A twisted.internet.interfaces.ITransport implementation which sends all its data
    straight into an IProtocol object: it exists to connect two IProtocols together.

    To use it, instantiate it with the receiving IProtocol, and then pass it to the
    sending IProtocol's makeConnection method:

        server = HTTPChannel()
        client.makeConnection(FakeTransport(server, self.reactor))

    If you want bidirectional communication, you'll need two instances.

    Written data is held in `buffer` and delivered to `other` by `flush`. Once a
    disconnect has fully completed, `disconnected` is True and `connected` is False.
    """

    # The Protocol object which will receive any data written to this transport.
    other: IProtocol

    # Test reactor
    _reactor: IReactorTime

    # The Protocol which is producing data for this transport. Optional, but if set
    # will get called back for connectionLost() notifications etc.
    _protocol: Optional[IProtocol] = None

    # The value to be returned by getPeer
    _peer_address: IAddress = attr.Factory(
        lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
    )

    # The value to be returned by getHost
    _host_address: IAddress = attr.Factory(
        lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
    )

    # Connection-state flags. These are unannotated, so they are plain class-level
    # defaults rather than attrs fields; assignments shadow them per instance.
    disconnecting = False
    disconnected = False
    connected = True
    # Data written but not yet delivered to `other`.
    buffer: bytes = b""
    producer: Optional[IPushProducer] = None
    # If True, each write schedules its own flush on the reactor.
    autoflush: bool = True

    def getPeer(self) -> IAddress:
        return self._peer_address

    def getHost(self) -> IAddress:
        return self._host_address

    def loseConnection(self) -> None:
        """Start a graceful disconnect.

        Notifies `_protocol` (if set) immediately. The transport is only marked
        disconnected once any buffered data has been flushed.
        """
        if not self.disconnecting:
            logger.info("FakeTransport: loseConnection()")
            self.disconnecting = True
            if self._protocol:
                self._protocol.connectionLost(
                    Failure(RuntimeError("FakeTransport.loseConnection()"))
                )

            # if we still have data to write, delay until that is done
            if self.buffer:
                logger.info(
                    "FakeTransport: Delaying disconnect until buffer is flushed"
                )
            else:
                self._mark_disconnected()

    def abortConnection(self) -> None:
        """Disconnect immediately; any buffered data will never be delivered."""
        logger.info("FakeTransport: abortConnection()")

        if not self.disconnecting:
            self.disconnecting = True
            if self._protocol:
                self._protocol.connectionLost(None)  # type: ignore[arg-type]

        self._mark_disconnected()

    def _mark_disconnected(self) -> None:
        # Keep `connected` and `disconnected` consistent on every disconnect path.
        self.connected = False
        self.disconnected = True

    def pauseProducing(self) -> None:
        if not self.producer:
            return

        self.producer.pauseProducing()

    def resumeProducing(self) -> None:
        if not self.producer:
            return
        self.producer.resumeProducing()

    def unregisterProducer(self) -> None:
        if not self.producer:
            return

        self.producer = None

    def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
        """Register a producer; a non-streaming one is polled every 0.1s."""
        self.producer = producer
        self.producerStreaming = streaming

        def _produce() -> None:
            if not self.producer:
                # we've been unregistered
                return
            # some implementations of IProducer (for example, FileSender)
            # don't return a deferred.
            d = maybeDeferred(self.producer.resumeProducing)
            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))

        if not streaming:
            self._reactor.callLater(0.0, _produce)

    def write(self, byt: bytes) -> None:
        """Buffer `byt` for delivery to `other`.

        Raises:
            Exception: if the transport is disconnecting.
        """
        if self.disconnecting:
            raise Exception("Writing to disconnecting FakeTransport")

        self.buffer = self.buffer + byt

        # always actually do the write asynchronously. Some protocols (notably the
        # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
        # still doing a write. Doing a callLater here breaks the cycle.
        if self.autoflush:
            self._reactor.callLater(0.0, self.flush)

    def writeSequence(self, seq: Iterable[bytes]) -> None:
        for x in seq:
            self.write(x)

    def flush(self, maxbytes: Optional[int] = None) -> None:
        """Deliver up to `maxbytes` (default: all) of the buffer to `other`.

        Completes a pending graceful disconnect once the buffer is empty.
        """
        if not self.buffer:
            # nothing to do. Don't write empty buffers: it upsets the
            # TLSMemoryBIOProtocol
            return

        if self.disconnected:
            return

        if maxbytes is not None:
            to_write = self.buffer[:maxbytes]
        else:
            to_write = self.buffer

        logger.info("%s->%s: %s", self._protocol, self.other, to_write)

        try:
            self.other.dataReceived(to_write)
        except Exception as e:
            logger.exception("Exception writing to protocol: %s", e)
            return

        self.buffer = self.buffer[len(to_write) :]
        if self.buffer and self.autoflush:
            self._reactor.callLater(0.0, self.flush)

        if not self.buffer and self.disconnecting:
            logger.info("FakeTransport: Buffer now empty, completing disconnect")
            self._mark_disconnected()
|
|
|
|
|
|
def connect_client(
    reactor: ThreadedMemoryReactorClock, client_id: int
) -> Tuple[IProtocol, AccumulatingProtocol]:
    """
    Connect a client to a fake TCP transport.

    Removes the pending connection attempt at index `client_id` from
    `reactor.tcpClients`, builds a protocol from its factory, and wires it to an
    `AccumulatingProtocol` "server" with a pair of `FakeTransport`s (one per
    direction).

    Args:
        reactor: the test reactor holding the pending TCP client connections.
        client_id: index into `reactor.tcpClients` of the connection to complete.

    Returns:
        A (client, server) tuple: the protocol built by the client's factory,
        and the accumulating protocol standing in for the server.
    """
    # Element 2 of each tcpClients entry is the client factory.
    factory = reactor.tcpClients.pop(client_id)[2]
    client = factory.buildProtocol(None)
    server = AccumulatingProtocol()
    server.makeConnection(FakeTransport(client, reactor))
    client.makeConnection(FakeTransport(server, reactor))

    return client, server
|
|
|
|
|
|
class TestHomeServer(HomeServer):
    """A HomeServer for tests, whose datastore class is `DataStore`."""

    DATASTORE_CLASS = DataStore  # type: ignore[assignment]
|
|
|
|
|
|
def setup_test_homeserver(
    cleanup_func: Callable[[Callable[[], None]], None],
    name: str = "test",
    config: Optional[HomeServerConfig] = None,
    reactor: Optional[ISynapseReactor] = None,
    homeserver_to_use: Type[HomeServer] = TestHomeServer,
    **kwargs: Any,
) -> HomeServer:
    """
    Set up a homeserver suitable for running tests against.

    A fresh test database is always configured. When USE_POSTGRES_FOR_TESTS
    is set, this is a uniquely-named PostgreSQL database created from the
    POSTGRES_BASE_DB template. Otherwise it is an SQLite database (in-memory,
    or `test.db` in the working directory when SQLITE_PERSIST_DB is set)
    seeded from a shared, pre-prepared template connection.

    Calling this method directly is deprecated: you should instead derive from
    HomeserverTestCase.

    Args:
        cleanup_func: The function used to register a cleanup routine for
            after the test. It is only called when running against PostgreSQL
            and LEAVE_DB is not set, to register dropping the test database.
        name: Name passed to the homeserver constructor and, if `config` is
            None, used to build the default config.
        config: The homeserver config. If None, a default one is generated.
        reactor: The reactor to use. If None, the global Twisted reactor is
            used.
        homeserver_to_use: The HomeServer class to instantiate.
        **kwargs: Each `key=value` pair is installed on the homeserver as the
            attribute `_<key>` (the backing attribute of a `@cache_in_self`
            getter). `clock` defaults to a MockClock. If `db_txn_limit` is
            given, it is also used as the database's `txn_limit`.

    Returns:
        The set-up homeserver.
    """
    if reactor is None:
        from twisted.internet import reactor as _reactor

        reactor = cast(ISynapseReactor, _reactor)

    if config is None:
        config = default_config(name, parse=True)

    config.caches.resize_all_caches()

    if "clock" not in kwargs:
        kwargs["clock"] = MockClock()

    if USE_POSTGRES_FOR_TESTS:
        test_db = "synapse_test_%s" % uuid.uuid4().hex

        database_config = {
            "name": "psycopg2",
            "args": {
                "dbname": test_db,
                "host": POSTGRES_HOST,
                "password": POSTGRES_PASSWORD,
                "user": POSTGRES_USER,
                "port": POSTGRES_PORT,
                "cp_min": 1,
                "cp_max": 5,
            },
        }
    else:
        if SQLITE_PERSIST_DB:
            # The current working directory is in _trial_temp, so this gets created within that directory.
            test_db_location = os.path.abspath("test.db")
            logger.debug("Will persist db to %s", test_db_location)
            # Ensure each test gets a clean database.
            try:
                os.remove(test_db_location)
            except FileNotFoundError:
                pass
            else:
                logger.debug("Removed existing DB at %s", test_db_location)
        else:
            test_db_location = ":memory:"

        database_config = {
            "name": "sqlite3",
            "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
        }

        # Check if we have set up a DB that we can use as a template.
        global PREPPED_SQLITE_DB_CONN
        if PREPPED_SQLITE_DB_CONN is None:
            temp_engine = create_engine(database_config)
            PREPPED_SQLITE_DB_CONN = LoggingDatabaseConnection(
                sqlite3.connect(":memory:"), temp_engine, "PREPPED_CONN"
            )

            database = DatabaseConnectionConfig("master", database_config)
            config.database.databases = [database]
            prepare_database(
                PREPPED_SQLITE_DB_CONN, create_engine(database_config), config
            )

        database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN

    if "db_txn_limit" in kwargs:
        database_config["txn_limit"] = kwargs["db_txn_limit"]

    database = DatabaseConnectionConfig("master", database_config)
    config.database.databases = [database]

    db_engine = create_engine(database.config)

    # Create the database before we actually try and connect to it, based off
    # the template database we generate in setupdb()
    if USE_POSTGRES_FOR_TESTS:
        db_conn = db_engine.module.connect(
            dbname=POSTGRES_BASE_DB,
            user=POSTGRES_USER,
            host=POSTGRES_HOST,
            port=POSTGRES_PORT,
            password=POSTGRES_PASSWORD,
        )
        # Always release the admin connection, even if DROP/CREATE fails.
        try:
            db_engine.attempt_to_set_autocommit(db_conn, True)
            cur = db_conn.cursor()
            try:
                cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                cur.execute(
                    "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
                )
            finally:
                cur.close()
        finally:
            db_conn.close()

    hs = homeserver_to_use(
        name,
        config=config,
        version_string="Synapse/tests",
        reactor=reactor,
    )

    # Install @cache_in_self attributes
    for key, val in kwargs.items():
        setattr(hs, "_" + key, val)

    # Mock TLS
    hs.tls_server_context_factory = Mock()

    # Patch `make_pool` before initialising the database, to make database transactions
    # synchronous for testing.
    with patch("synapse.storage.database.make_pool", side_effect=make_fake_db_pool):
        hs.setup()

    # Since we've changed the databases to run DB transactions on the same
    # thread, we need to stop the event fetcher hogging that one thread.
    hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False

    if USE_POSTGRES_FOR_TESTS:
        database_pool = hs.get_datastores().databases[0]

        # We need to do cleanup on PostgreSQL
        def cleanup() -> None:
            import psycopg2

            # Close all the db pools
            database_pool._db_pool.close()

            dropped = False

            # Drop the test database
            db_conn = db_engine.module.connect(
                dbname=POSTGRES_BASE_DB,
                user=POSTGRES_USER,
                host=POSTGRES_HOST,
                port=POSTGRES_PORT,
                password=POSTGRES_PASSWORD,
            )
            try:
                db_engine.attempt_to_set_autocommit(db_conn, True)
                cur = db_conn.cursor()
                try:
                    # Try a few times to drop the DB. Some things may hold on to the
                    # database for a few more seconds due to flakiness, preventing
                    # us from dropping it when the test is over. If we can't drop
                    # it, warn and move on.
                    for _ in range(5):
                        try:
                            cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                            db_conn.commit()
                        except psycopg2.OperationalError as e:
                            warnings.warn(
                                "Couldn't drop old db: " + str(e),
                                category=UserWarning,
                                stacklevel=2,
                            )
                            time.sleep(0.5)
                        else:
                            dropped = True
                            # Stop retrying once the drop has succeeded, so we
                            # don't issue redundant drops or spurious warnings.
                            break
                finally:
                    cur.close()
            finally:
                db_conn.close()

            if not dropped:
                warnings.warn(
                    "Failed to drop old DB.",
                    category=UserWarning,
                    stacklevel=2,
                )

        if not LEAVE_DB:
            # Register the cleanup hook
            cleanup_func(cleanup)

    # bcrypt is far too slow to be doing in unit tests
    # Need to let the HS build an auth handler and then mess with it
    # because AuthHandler's constructor requires the HS, so we can't make one
    # beforehand and pass it in to the HS's constructor (chicken / egg)
    async def fast_hash(p: str) -> str:
        return hashlib.md5(p.encode("utf8")).hexdigest()

    hs.get_auth_handler().hash = fast_hash  # type: ignore[assignment]

    async def fast_validate_hash(p: str, h: str) -> bool:
        return hashlib.md5(p.encode("utf8")).hexdigest() == h

    hs.get_auth_handler().validate_hash = fast_validate_hash  # type: ignore[assignment]

    # Load any configured modules into the homeserver
    module_api = hs.get_module_api()
    for module, module_config in hs.config.modules.loaded_modules:
        module(config=module_config, api=module_api)

    load_legacy_spam_checkers(hs)
    load_legacy_third_party_event_rules(hs)
    load_legacy_presence_router(hs)
    load_legacy_password_auth_providers(hs)

    return hs
|