diff --git a/changelog.d/11503.misc b/changelog.d/11503.misc deleted file mode 100644 index 03a24a922..000000000 --- a/changelog.d/11503.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`. \ No newline at end of file diff --git a/tests/server.py b/tests/server.py index b29df3759..40cf5b12c 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import hashlib + import json import logging -import time -import uuid -import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -30,7 +27,6 @@ from typing import ( Type, Union, ) -from unittest.mock import Mock import attr from typing_extensions import Deque @@ -57,24 +53,11 @@ from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site -from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest -from synapse.server import HomeServer -from synapse.storage import DataStore -from synapse.storage.engines import PostgresEngine, create_engine from synapse.types import JsonDict from synapse.util import Clock -from tests.utils import ( - LEAVE_DB, - POSTGRES_BASE_DB, - POSTGRES_HOST, - POSTGRES_PASSWORD, - POSTGRES_USER, - USE_POSTGRES_FOR_TESTS, - MockClock, - default_config, -) +from tests.utils import setup_test_homeserver as _sth logger = logging.getLogger(__name__) @@ -467,11 +450,14 @@ class ThreadPool: return d -def make_test_homeserver_synchronous(server: HomeServer) -> None: +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ - Make the given test homeserver's database interactions synchronous. + Set up a synchronous test server, driven by the reactor used by + the homeserver. """ + server = _sth(cleanup_func, *args, **kwargs) + # Make the thread pool synchronous. clock = server.get_clock() for database in server.get_datastores().databases: @@ -499,7 +485,6 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction - # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) pool.running = True @@ -507,6 +492,8 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: # thread, so we need to disable the dedicated thread behaviour. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server + def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() @@ -686,171 +673,3 @@ def connect_client( client.makeConnection(FakeTransport(server, reactor)) return client, server - - -class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore - - -def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, - homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): - """ - Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. - - If no datastore is supplied, one is created and given to the homeserver. - - Args: - cleanup_func : The function used to register a cleanup routine for - after the test. - - Calling this method directly is deprecated: you should instead derive from - HomeserverTestCase. - """ - if reactor is None: - from twisted.internet import reactor - - if config is None: - config = default_config(name, parse=True) - - config.ldap_enabled = False - - if "clock" not in kwargs: - kwargs["clock"] = MockClock() - - if USE_POSTGRES_FOR_TESTS: - test_db = "synapse_test_%s" % uuid.uuid4().hex - - database_config = { - "name": "psycopg2", - "args": { - "database": test_db, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "user": POSTGRES_USER, - "cp_min": 1, - "cp_max": 5, - }, - } - else: - database_config = { - "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, - } - - if "db_txn_limit" in kwargs: - database_config["txn_limit"] = kwargs["db_txn_limit"] - - database = DatabaseConnectionConfig("master", database_config) - config.database.databases = [database] - - db_engine = create_engine(database.config) - - # Create the database before we actually try and connect to it, based off - # the template database we generate in setupdb() - if isinstance(db_engine, PostgresEngine): - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - cur.execute( - "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) - ) - cur.close() - db_conn.close() - - hs = homeserver_to_use( - name, - config=config, - version_string="Synapse/tests", - reactor=reactor, - ) - - # Install @cache_in_self attributes - for key, val in kwargs.items(): - setattr(hs, "_" + key, val) - - # Mock TLS - hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() - - hs.setup() - if homeserver_to_use == TestHomeServer: - hs.setup_background_tasks() - - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] - - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 - - # Close all the db pools - database._db_pool.close() - - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for _ in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) - - # bcrypt is far too slow to be doing in unit tests - # Need to let the HS build an auth handler and then mess with it - # because AuthHandler's constructor requires the HS, so we can't make one - # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): - return hashlib.md5(p.encode("utf8")).hexdigest() - - hs.get_auth_handler().hash = hash - - async def validate_hash(p, h): - return hashlib.md5(p.encode("utf8")).hexdigest() == h - - hs.get_auth_handler().validate_hash = validate_hash - - # Make the threadpool and database transactions synchronous for testing. - make_test_homeserver_synchronous(hs) - - return hs diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3e4f0579c..ddad44bd6 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -23,8 +23,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest -from tests.server import TestHomeServer -from tests.utils import default_config +from tests.utils import TestHomeServer, default_config class SQLBaseStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b8..fccab733c 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -19,8 +19,8 @@ from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest -from tests.server import TestHomeServer from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 6d013e851..983859120 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,12 @@ # limitations under the License. import atexit +import hashlib import os +import time +import uuid +import warnings +from typing import Type from unittest.mock import Mock, patch from urllib import parse as urlparse @@ -23,11 +28,14 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context +from synapse.server import HomeServer +from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import create_engine +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -174,6 +182,171 @@ def default_config(name, parse=False): return config_dict +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + +def setup_test_homeserver( + cleanup_func, + name="test", + config=None, + reactor=None, + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs, +): + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. + + Calling this method directly is deprecated: you should instead derive from + HomeserverTestCase. + """ + if reactor is None: + from twisted.internet import reactor + + if config is None: + config = default_config(name, parse=True) + + config.ldap_enabled = False + + if "clock" not in kwargs: + kwargs["clock"] = MockClock() + + if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + + database_config = { + "name": "psycopg2", + "args": { + "database": test_db, + "host": POSTGRES_HOST, + "password": POSTGRES_PASSWORD, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + else: + database_config = { + "name": "sqlite3", + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + } + + if "db_txn_limit" in kwargs: + database_config["txn_limit"] = kwargs["db_txn_limit"] + + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) + + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + + hs = homeserver_to_use( + name, + config=config, + version_string="Synapse/tests", + reactor=reactor, + ) + + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, "_" + key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + + hs.setup() + if homeserver_to_use == TestHomeServer: + hs.setup_background_tasks() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 + + # Close all the db pools + database._db_pool.close() + + dropped = False + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for _ in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) + + # bcrypt is far too slow to be doing in unit tests + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash + + return hs + + def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {}