mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-16 15:01:23 +01:00
Run on_new_connection for unit tests
Configure the connectionpool used for unit tests to run the `on_new_connection` function.
This commit is contained in:
parent
d8f90c4208
commit
b178eca261
1 changed files with 17 additions and 7 deletions
|
@ -66,13 +66,19 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
if "clock" not in kargs:
|
if "clock" not in kargs:
|
||||||
kargs["clock"] = MockClock()
|
kargs["clock"] = MockClock()
|
||||||
|
|
||||||
|
db_engine = create_engine(config.database_config)
|
||||||
if datastore is None:
|
if datastore is None:
|
||||||
db_pool = SQLiteMemoryDbPool()
|
# we need to configure the connection pool to run the on_new_connection
|
||||||
|
# function, so that we can test code that uses custom sqlite functions
|
||||||
|
# (like rank).
|
||||||
|
db_pool = SQLiteMemoryDbPool(
|
||||||
|
cp_openfun=db_engine.on_new_connection,
|
||||||
|
)
|
||||||
yield db_pool.prepare()
|
yield db_pool.prepare()
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=db_pool, config=config,
|
name, db_pool=db_pool, config=config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine(config.database_config),
|
database_engine=db_engine,
|
||||||
get_db_conn=db_pool.get_db_conn,
|
get_db_conn=db_pool.get_db_conn,
|
||||||
room_list_handler=object(),
|
room_list_handler=object(),
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
|
@ -83,7 +89,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
hs = HomeServer(
|
hs = HomeServer(
|
||||||
name, db_pool=None, datastore=datastore, config=config,
|
name, db_pool=None, datastore=datastore, config=config,
|
||||||
version_string="Synapse/tests",
|
version_string="Synapse/tests",
|
||||||
database_engine=create_engine(config.database_config),
|
database_engine=db_engine,
|
||||||
room_list_handler=object(),
|
room_list_handler=object(),
|
||||||
tls_server_context_factory=Mock(),
|
tls_server_context_factory=Mock(),
|
||||||
**kargs
|
**kargs
|
||||||
|
@ -303,11 +309,15 @@ class MockClock(object):
|
||||||
|
|
||||||
|
|
||||||
class SQLiteMemoryDbPool(ConnectionPool, object):
|
class SQLiteMemoryDbPool(ConnectionPool, object):
|
||||||
def __init__(self):
|
def __init__(self, **kwargs):
|
||||||
|
connkw = {
|
||||||
|
"cp_min": 1,
|
||||||
|
"cp_max": 1,
|
||||||
|
}
|
||||||
|
connkw.update(kwargs)
|
||||||
|
|
||||||
super(SQLiteMemoryDbPool, self).__init__(
|
super(SQLiteMemoryDbPool, self).__init__(
|
||||||
"sqlite3", ":memory:",
|
"sqlite3", ":memory:", **connkw
|
||||||
cp_min=1,
|
|
||||||
cp_max=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.config = Mock()
|
self.config = Mock()
|
||||||
|
|
Loading…
Reference in a new issue