0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-24 03:59:06 +02:00

Merge pull request #2829 from matrix-org/rav/postgres_in_tests

Make it possible to run tests against postgres
This commit is contained in:
Richard van der Hoff 2018-01-27 18:20:55 +01:00 committed by GitHub
commit 2186d7c06e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 38 deletions

View file

@ -167,7 +167,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server # wait a tick for it to send the request to the perspectives server
# (it first tries the datastore) # (it first tries the datastore)
yield async.sleep(0.005) yield async.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once() self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11) self.assertIs(LoggingContext.current_context(), context_11)
@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)], [("server10", json1)],
) )
yield async.sleep(0.005) yield async.sleep(01)
self.http_client.post_json.assert_not_called() self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None) res_deferreds_2[0].addBoth(self.check_context, None)

View file

@ -19,18 +19,23 @@ import urllib
import urlparse import urlparse
from mock import Mock, patch from mock import Mock, patch
from twisted.enterprise.adbapi import ConnectionPool
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.api.errors import CodeMessageException, cs_error from synapse.api.errors import CodeMessageException, cs_error
from synapse.federation.transport import server from synapse.federation.transport import server
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import PostgresEngine
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
# It requires you to have a local postgres database called synapse_test, within
# which ALL TABLES WILL BE DROPPED
USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks @defer.inlineCallbacks
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
@ -60,30 +65,62 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.update_user_directory = False config.update_user_directory = False
config.use_frozen_dicts = True config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}
config.ldap_enabled = False config.ldap_enabled = False
if "clock" not in kargs: if "clock" not in kargs:
kargs["clock"] = MockClock() kargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
config.database_config = {
"name": "psycopg2",
"args": {
"database": "synapse_test",
"cp_min": 1,
"cp_max": 5,
},
}
else:
config.database_config = {
"name": "sqlite3",
"args": {
"database": ":memory:",
"cp_min": 1,
"cp_max": 1,
},
}
db_engine = create_engine(config.database_config)
# we need to configure the connection pool to run the on_new_connection
# function, so that we can test code that uses custom sqlite functions
# (like rank).
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
if datastore is None: if datastore is None:
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer( hs = HomeServer(
name, db_pool=db_pool, config=config, name, config=config,
db_config=config.database_config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config.database_config), database_engine=db_engine,
get_db_conn=db_pool.get_db_conn,
room_list_handler=object(), room_list_handler=object(),
tls_server_context_factory=Mock(), tls_server_context_factory=Mock(),
**kargs **kargs
) )
db_conn = hs.get_db_conn()
# make sure that the database is empty
if isinstance(db_engine, PostgresEngine):
cur = db_conn.cursor()
cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
rows = cur.fetchall()
for r in rows:
cur.execute("DROP TABLE %s CASCADE" % r[0])
yield prepare_database(db_conn, db_engine, config)
hs.setup() hs.setup()
else: else:
hs = HomeServer( hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config, name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=create_engine(config.database_config), database_engine=db_engine,
room_list_handler=object(), room_list_handler=object(),
tls_server_context_factory=Mock(), tls_server_context_factory=Mock(),
**kargs **kargs
@ -302,34 +339,6 @@ class MockClock(object):
return d return d
class SQLiteMemoryDbPool(ConnectionPool, object):
def __init__(self):
super(SQLiteMemoryDbPool, self).__init__(
"sqlite3", ":memory:",
cp_min=1,
cp_max=1,
)
self.config = Mock()
self.config.password_providers = []
self.config.database_config = {"name": "sqlite3"}
def prepare(self):
engine = self.create_engine()
return self.runWithConnection(
lambda conn: prepare_database(conn, engine, self.config)
)
def get_db_conn(self):
conn = self.connect()
engine = self.create_engine()
prepare_database(conn, engine, self.config)
return conn
def create_engine(self):
return create_engine(self.config.database_config)
def _format_call(args, kwargs): def _format_call(args, kwargs):
return ", ".join( return ", ".join(
["%r" % (a) for a in args] + ["%r" % (a) for a in args] +