mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 05:23:51 +01:00
Merge pull request #534 from matrix-org/erikj/setup
Add a Homeserver.setup method
This commit is contained in:
commit
167d1df699
12 changed files with 123 additions and 120 deletions
|
@ -254,6 +254,18 @@ class SynapseHomeServer(HomeServer):
|
|||
except IncorrectDatabaseSetup as e:
|
||||
quit_with_error(e.message)
|
||||
|
||||
def get_db_conn(self):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
|
||||
def quit_with_error(error_string):
|
||||
message_lines = error_string.split("\n")
|
||||
|
@ -390,13 +402,7 @@ def setup(config_options):
|
|||
logger.info("Preparing database: %s...", config.database_config['name'])
|
||||
|
||||
try:
|
||||
db_conn = database_engine.module.connect(
|
||||
**{
|
||||
k: v for k, v in config.database_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
)
|
||||
|
||||
db_conn = hs.get_db_conn()
|
||||
database_engine.prepare_database(db_conn)
|
||||
hs.run_startup_checks(db_conn, database_engine)
|
||||
|
||||
|
@ -411,14 +417,18 @@ def setup(config_options):
|
|||
|
||||
logger.info("Database prepared in %s.", config.database_config['name'])
|
||||
|
||||
hs.setup()
|
||||
hs.start_listening()
|
||||
|
||||
def start():
|
||||
hs.get_pusherpool().start()
|
||||
hs.get_state_handler().start_caching()
|
||||
hs.get_datastore().start_profiling()
|
||||
hs.get_datastore().start_doing_background_updates()
|
||||
hs.get_replication_layer().start_get_pdu_cache()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return hs
|
||||
|
||||
|
||||
|
|
|
@ -40,6 +40,11 @@ from synapse.api.filtering import Filtering
|
|||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HomeServer(object):
|
||||
"""A basic homeserver object without lazy component builders.
|
||||
|
@ -102,10 +107,19 @@ class HomeServer(object):
|
|||
self.hostname = hostname
|
||||
self._building = {}
|
||||
|
||||
self.clock = Clock()
|
||||
self.distributor = Distributor()
|
||||
self.ratelimiter = Ratelimiter()
|
||||
|
||||
# Other kwargs are explicit dependencies
|
||||
for depname in kwargs:
|
||||
setattr(self, depname, kwargs[depname])
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = DataStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def get_ip_from_request(self, request):
|
||||
# X-Forwarded-For is handled by our custom request type.
|
||||
return request.getClientIP()
|
||||
|
@ -116,15 +130,9 @@ class HomeServer(object):
|
|||
def is_mine_id(self, string):
|
||||
return string.split(":", 1)[1] == self.hostname
|
||||
|
||||
def build_clock(self):
|
||||
return Clock()
|
||||
|
||||
def build_replication_layer(self):
|
||||
return initialize_http_replication(self)
|
||||
|
||||
def build_datastore(self):
|
||||
return DataStore(self)
|
||||
|
||||
def build_handlers(self):
|
||||
return Handlers(self)
|
||||
|
||||
|
@ -135,10 +143,9 @@ class HomeServer(object):
|
|||
return Auth(self)
|
||||
|
||||
def build_http_client_context_factory(self):
|
||||
config = self.get_config()
|
||||
return (
|
||||
InsecureInterceptableContextFactory()
|
||||
if config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
else BrowserLikePolicyForHTTPS()
|
||||
)
|
||||
|
||||
|
@ -157,15 +164,9 @@ class HomeServer(object):
|
|||
def build_state_handler(self):
|
||||
return StateHandler(self)
|
||||
|
||||
def build_distributor(self):
|
||||
return Distributor()
|
||||
|
||||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
def build_ratelimiter(self):
|
||||
return Ratelimiter()
|
||||
|
||||
def build_keyring(self):
|
||||
return Keyring(self)
|
||||
|
||||
|
|
|
@ -46,6 +46,9 @@ from .tags import TagsStore
|
|||
from .account_data import AccountDataStore
|
||||
|
||||
|
||||
from util.id_generators import IdGenerator, StreamIdGenerator
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
|
@ -79,18 +82,43 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
EventPushActionsStore
|
||||
):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DataStore, self).__init__(hs)
|
||||
def __init__(self, db_conn, hs):
|
||||
self.hs = hs
|
||||
|
||||
self.min_token_deferred = self._get_min_token()
|
||||
self.min_token = None
|
||||
cur = db_conn.cursor()
|
||||
try:
|
||||
cur.execute("SELECT MIN(stream_ordering) FROM events",)
|
||||
rows = cur.fetchall()
|
||||
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
|
||||
self.min_stream_token = min(self.min_stream_token, -1)
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen",
|
||||
keylen=4,
|
||||
)
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering"
|
||||
)
|
||||
self._receipts_id_gen = StreamIdGenerator(
|
||||
db_conn, "receipts_linearized", "stream_id"
|
||||
)
|
||||
self._account_data_id_gen = StreamIdGenerator(
|
||||
db_conn, "account_data_max_stream_id", "stream_id"
|
||||
)
|
||||
|
||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||
|
||||
super(DataStore, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
||||
now = int(self._clock.time_msec())
|
||||
|
|
|
@ -15,13 +15,11 @@
|
|||
import logging
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||
from synapse.util.caches.descriptors import Cache
|
||||
import synapse.metrics
|
||||
|
||||
from util.id_generators import IdGenerator, StreamIdGenerator
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -175,16 +173,6 @@ class SQLBaseStore(object):
|
|||
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
|
||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
|
||||
|
||||
def start_profiling(self):
|
||||
self._previous_loop_ts = self._clock.time_msec()
|
||||
|
||||
|
@ -345,7 +333,8 @@ class SQLBaseStore(object):
|
|||
|
||||
defer.returnValue(result)
|
||||
|
||||
def cursor_to_dict(self, cursor):
|
||||
@staticmethod
|
||||
def cursor_to_dict(cursor):
|
||||
"""Converts a SQL cursor into an list of dicts.
|
||||
|
||||
Args:
|
||||
|
@ -402,8 +391,8 @@ class SQLBaseStore(object):
|
|||
if not or_ignore:
|
||||
raise
|
||||
|
||||
@log_function
|
||||
def _simple_insert_txn(self, txn, table, values):
|
||||
@staticmethod
|
||||
def _simple_insert_txn(txn, table, values):
|
||||
keys, vals = zip(*values.items())
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
|
@ -414,7 +403,8 @@ class SQLBaseStore(object):
|
|||
|
||||
txn.execute(sql, vals)
|
||||
|
||||
def _simple_insert_many_txn(self, txn, table, values):
|
||||
@staticmethod
|
||||
def _simple_insert_many_txn(txn, table, values):
|
||||
if not values:
|
||||
return
|
||||
|
||||
|
@ -537,9 +527,10 @@ class SQLBaseStore(object):
|
|||
table, keyvalues, retcol, allow_none=allow_none,
|
||||
)
|
||||
|
||||
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
|
||||
@classmethod
|
||||
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
ret = self._simple_select_onecol_txn(
|
||||
ret = cls._simple_select_onecol_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
|
@ -554,7 +545,8 @@ class SQLBaseStore(object):
|
|||
else:
|
||||
raise StoreError(404, "No row found")
|
||||
|
||||
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
|
||||
@staticmethod
|
||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
||||
) % {
|
||||
|
@ -603,7 +595,8 @@ class SQLBaseStore(object):
|
|||
table, keyvalues, retcols
|
||||
)
|
||||
|
||||
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
|
||||
@classmethod
|
||||
def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
|
@ -627,7 +620,7 @@ class SQLBaseStore(object):
|
|||
)
|
||||
txn.execute(sql)
|
||||
|
||||
return self.cursor_to_dict(txn)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
||||
|
@ -662,7 +655,8 @@ class SQLBaseStore(object):
|
|||
|
||||
defer.returnValue(results)
|
||||
|
||||
def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
|
||||
@classmethod
|
||||
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
|
@ -699,7 +693,7 @@ class SQLBaseStore(object):
|
|||
)
|
||||
|
||||
txn.execute(sql, values)
|
||||
return self.cursor_to_dict(txn)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
desc="_simple_update_one"):
|
||||
|
@ -726,7 +720,8 @@ class SQLBaseStore(object):
|
|||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
||||
@staticmethod
|
||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
|
@ -743,7 +738,8 @@ class SQLBaseStore(object):
|
|||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
||||
@staticmethod
|
||||
def _simple_select_one_txn(txn, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
|
@ -784,7 +780,8 @@ class SQLBaseStore(object):
|
|||
raise StoreError(500, "more than one row matched")
|
||||
return self.runInteraction(desc, func)
|
||||
|
||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
||||
@staticmethod
|
||||
def _simple_delete_txn(txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
|
|
|
@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
|
|||
return
|
||||
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
yield self.min_token_deferred
|
||||
start = self.min_token - 1
|
||||
self.min_token -= len(events_and_contexts) + 1
|
||||
stream_orderings = range(start, self.min_token, -1)
|
||||
start = self.min_stream_token - 1
|
||||
self.min_stream_token -= len(events_and_contexts) + 1
|
||||
stream_orderings = range(start, self.min_stream_token, -1)
|
||||
|
||||
@contextmanager
|
||||
def stream_ordering_manager():
|
||||
|
@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
|
|||
is_new_state=True, current_state=None):
|
||||
stream_ordering = None
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
yield self.min_token_deferred
|
||||
self.min_token -= 1
|
||||
stream_ordering = self.min_token
|
||||
self.min_stream_token -= 1
|
||||
stream_ordering = self.min_stream_token
|
||||
|
||||
if stream_ordering is None:
|
||||
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
||||
|
|
|
@ -31,7 +31,9 @@ class ReceiptsStore(SQLBaseStore):
|
|||
def __init__(self, hs):
|
||||
super(ReceiptsStore, self).__init__(hs)
|
||||
|
||||
self._receipts_stream_cache = _RoomStreamChangeCache()
|
||||
self._receipts_stream_cache = _RoomStreamChangeCache(
|
||||
self._receipts_id_gen.get_max_token(None)
|
||||
)
|
||||
|
||||
@cached(num_args=2)
|
||||
def get_receipts_for_room(self, room_id, receipt_type):
|
||||
|
@ -377,11 +379,11 @@ class _RoomStreamChangeCache(object):
|
|||
may have changed since that key. If the key is too old then the cache
|
||||
will simply return all rooms.
|
||||
"""
|
||||
def __init__(self, size_of_cache=10000):
|
||||
def __init__(self, current_key, size_of_cache=10000):
|
||||
self._size_of_cache = size_of_cache
|
||||
self._room_to_key = {}
|
||||
self._cache = sorteddict()
|
||||
self._earliest_key = None
|
||||
self._earliest_key = current_key
|
||||
self.name = "ReceiptsRoomChangeCache"
|
||||
caches_by_name[self.name] = self._cache
|
||||
|
||||
|
|
|
@ -444,19 +444,6 @@ class StreamStore(SQLBaseStore):
|
|||
rows = txn.fetchall()
|
||||
return rows[0][0] if rows else 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_min_token(self):
|
||||
row = yield self._execute(
|
||||
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
||||
)
|
||||
|
||||
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
||||
self.min_token = min(self.min_token, -1)
|
||||
|
||||
logger.debug("min_token is: %s", self.min_token)
|
||||
|
||||
defer.returnValue(self.min_token)
|
||||
|
||||
@staticmethod
|
||||
def _set_before_and_after(events, rows):
|
||||
for event, row in zip(events, rows):
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from twisted.internet import defer
|
||||
from .util.id_generators import StreamIdGenerator
|
||||
|
||||
import ujson as json
|
||||
import logging
|
||||
|
@ -25,12 +24,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TagsStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(TagsStore, self).__init__(hs)
|
||||
|
||||
self._account_data_id_gen = StreamIdGenerator(
|
||||
"account_data_max_stream_id", "stream_id"
|
||||
)
|
||||
|
||||
def get_max_account_data_stream_id(self):
|
||||
"""Get the current max stream id for the private user data stream
|
||||
|
|
|
@ -72,28 +72,24 @@ class StreamIdGenerator(object):
|
|||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
def __init__(self, table, column):
|
||||
def __init__(self, db_conn, table, column):
|
||||
self.table = table
|
||||
self.column = column
|
||||
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._current_max = None
|
||||
cur = db_conn.cursor()
|
||||
self._current_max = self._get_or_compute_current_max(cur)
|
||||
cur.close()
|
||||
|
||||
self._unfinished_ids = deque()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_next(self, store):
|
||||
"""
|
||||
Usage:
|
||||
with yield stream_id_gen.get_next as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
if not self._current_max:
|
||||
yield store.runInteraction(
|
||||
"_compute_current_max",
|
||||
self._get_or_compute_current_max,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._current_max += 1
|
||||
next_id = self._current_max
|
||||
|
@ -108,21 +104,14 @@ class StreamIdGenerator(object):
|
|||
with self._lock:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
defer.returnValue(manager())
|
||||
return manager()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_next_mult(self, store, n):
|
||||
"""
|
||||
Usage:
|
||||
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
if not self._current_max:
|
||||
yield store.runInteraction(
|
||||
"_compute_current_max",
|
||||
self._get_or_compute_current_max,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
next_ids = range(self._current_max + 1, self._current_max + n + 1)
|
||||
self._current_max += n
|
||||
|
@ -139,24 +128,17 @@ class StreamIdGenerator(object):
|
|||
for next_id in next_ids:
|
||||
self._unfinished_ids.remove(next_id)
|
||||
|
||||
defer.returnValue(manager())
|
||||
return manager()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_max_token(self, store):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
"""
|
||||
if not self._current_max:
|
||||
yield store.runInteraction(
|
||||
"_compute_current_max",
|
||||
self._get_or_compute_current_max,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
defer.returnValue(self._unfinished_ids[0] - 1)
|
||||
return self._unfinished_ids[0] - 1
|
||||
|
||||
defer.returnValue(self._current_max)
|
||||
return self._current_max
|
||||
|
||||
def _get_or_compute_current_max(self, txn):
|
||||
with self._lock:
|
||||
|
|
|
@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
f2 = self._write_config(suffix="2")
|
||||
|
||||
config = Mock(app_service_config_files=[f1, f2])
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||
|
||||
ApplicationServiceStore(hs)
|
||||
|
||||
|
@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
f2 = self._write_config(id="id", suffix="2")
|
||||
|
||||
config = Mock(app_service_config_files=[f1, f2])
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
|
@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||
f2 = self._write_config(as_token="as_token", suffix="2")
|
||||
|
||||
config = Mock(app_service_config_files=[f1, f2])
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
hs = yield setup_test_homeserver(config=config, datastore=Mock())
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
|
|
|
@ -18,7 +18,6 @@ from tests import unittest
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage.registration import RegistrationStore
|
||||
from synapse.util import stringutils
|
||||
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
hs = yield setup_test_homeserver()
|
||||
self.db_pool = hs.get_db_pool()
|
||||
|
||||
self.store = RegistrationStore(hs)
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.user_id = "@my-user:test"
|
||||
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
|
||||
|
|
|
@ -60,8 +60,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||
name, db_pool=db_pool, config=config,
|
||||
version_string="Synapse/tests",
|
||||
database_engine=create_engine("sqlite3"),
|
||||
get_db_conn=db_pool.get_db_conn,
|
||||
**kargs
|
||||
)
|
||||
hs.setup()
|
||||
else:
|
||||
hs = HomeServer(
|
||||
name, db_pool=None, datastore=datastore, config=config,
|
||||
|
@ -280,6 +282,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
|
|||
lambda conn: prepare_database(conn, engine)
|
||||
)
|
||||
|
||||
def get_db_conn(self):
|
||||
conn = self.connect()
|
||||
engine = create_engine("sqlite3")
|
||||
prepare_database(conn, engine)
|
||||
return conn
|
||||
|
||||
|
||||
class MemoryDataStore(object):
|
||||
|
||||
|
|
Loading…
Reference in a new issue