0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 09:24:04 +01:00

Move tests.utils.setup_test_homeserver to tests.server

It had no users.

We have just taken the identity of a previous function but don't provide the same
behaviour, so we need to fix this in the next commit...
This commit is contained in:
Olivier Wilkinson (reivilibre) 2021-12-03 11:37:21 +00:00
parent f7ec6e7d9e
commit b3fd99b74a
4 changed files with 188 additions and 177 deletions

View file

@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import json import json
import logging import logging
import time
import uuid
import warnings
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import ( from typing import (
@ -27,6 +30,7 @@ from typing import (
Type, Type,
Union, Union,
) )
from unittest.mock import Mock
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque
@ -53,10 +57,24 @@ from twisted.web.http_headers import Headers
from twisted.web.resource import IResource from twisted.web.resource import IResource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests.utils import (
LEAVE_DB,
POSTGRES_BASE_DB,
POSTGRES_HOST,
POSTGRES_PASSWORD,
POSTGRES_USER,
USE_POSTGRES_FOR_TESTS,
MockClock,
default_config,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -668,3 +686,168 @@ def connect_client(
client.makeConnection(FakeTransport(server, reactor)) client.makeConnection(FakeTransport(server, reactor))
return client, server return client, server
class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
def setup_test_homeserver(
cleanup_func,
name="test",
config=None,
reactor=None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs,
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
If no datastore is supplied, one is created and given to the homeserver.
Args:
cleanup_func : The function used to register a cleanup routine for
after the test.
Calling this method directly is deprecated: you should instead derive from
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor
if config is None:
config = default_config(name, parse=True)
config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
"host": POSTGRES_HOST,
"password": POSTGRES_PASSWORD,
"user": POSTGRES_USER,
"cp_min": 1,
"cp_max": 5,
},
}
else:
database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
cur.execute(
"CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
)
cur.close()
db_conn.close()
hs = homeserver_to_use(
name,
config=config,
version_string="Synapse/tests",
reactor=reactor,
)
# Install @cache_in_self attributes
for key, val in kwargs.items():
setattr(hs, "_" + key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()
hs.tls_client_options_factory = Mock()
hs.setup()
if homeserver_to_use == TestHomeServer:
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
database = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
database._db_pool.close()
dropped = False
# Drop the test database
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
# Try a few times to drop the DB. Some things may hold on to the
# database for a few more seconds due to flakiness, preventing
# us from dropping it when the test is over. If we can't drop
# it, warn and move on.
for _ in range(5):
try:
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit()
dropped = True
except psycopg2.OperationalError as e:
warnings.warn(
"Couldn't drop old db: " + str(e), category=UserWarning
)
time.sleep(0.5)
cur.close()
db_conn.close()
if not dropped:
warnings.warn("Failed to drop old DB.", category=UserWarning)
if not LEAVE_DB:
# Register the cleanup hook
cleanup_func(cleanup)
# bcrypt is far too slow to be doing in unit tests
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p):
return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash
async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash
return hs

View file

@ -23,7 +23,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from tests import unittest from tests import unittest
from tests.utils import TestHomeServer, default_config from tests.server import TestHomeServer
from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase): class SQLBaseStoreTestCase(unittest.TestCase):

View file

@ -19,8 +19,8 @@ from synapse.rest.client import login, room
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from tests import unittest from tests import unittest
from tests.server import TestHomeServer
from tests.test_utils import event_injection from tests.test_utils import event_injection
from tests.utils import TestHomeServer
class RoomMemberStoreTestCase(unittest.HomeserverTestCase): class RoomMemberStoreTestCase(unittest.HomeserverTestCase):

View file

@ -14,12 +14,7 @@
# limitations under the License. # limitations under the License.
import atexit import atexit
import hashlib
import os import os
import time
import uuid
import warnings
from typing import Type
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from urllib import parse as urlparse from urllib import parse as urlparse
@ -28,14 +23,11 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
# set this to True to run the tests against postgres instead of sqlite. # set this to True to run the tests against postgres instead of sqlite.
@ -182,171 +174,6 @@ def default_config(name, parse=False):
return config_dict return config_dict
class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
def setup_test_homeserver(
cleanup_func,
name="test",
config=None,
reactor=None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs,
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
If no datastore is supplied, one is created and given to the homeserver.
Args:
cleanup_func : The function used to register a cleanup routine for
after the test.
Calling this method directly is deprecated: you should instead derive from
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor
if config is None:
config = default_config(name, parse=True)
config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
"host": POSTGRES_HOST,
"password": POSTGRES_PASSWORD,
"user": POSTGRES_USER,
"cp_min": 1,
"cp_max": 5,
},
}
else:
database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
cur.execute(
"CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
)
cur.close()
db_conn.close()
hs = homeserver_to_use(
name,
config=config,
version_string="Synapse/tests",
reactor=reactor,
)
# Install @cache_in_self attributes
for key, val in kwargs.items():
setattr(hs, "_" + key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()
hs.tls_client_options_factory = Mock()
hs.setup()
if homeserver_to_use == TestHomeServer:
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
database = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
database._db_pool.close()
dropped = False
# Drop the test database
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
# Try a few times to drop the DB. Some things may hold on to the
# database for a few more seconds due to flakiness, preventing
# us from dropping it when the test is over. If we can't drop
# it, warn and move on.
for _ in range(5):
try:
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit()
dropped = True
except psycopg2.OperationalError as e:
warnings.warn(
"Couldn't drop old db: " + str(e), category=UserWarning
)
time.sleep(0.5)
cur.close()
db_conn.close()
if not dropped:
warnings.warn("Failed to drop old DB.", category=UserWarning)
if not LEAVE_DB:
# Register the cleanup hook
cleanup_func(cleanup)
# bcrypt is far too slow to be doing in unit tests
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p):
return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash
async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash
return hs
def mock_getRawHeaders(headers=None): def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {} headers = headers if headers is not None else {}