Add a Homeserver.setup method.

This is for setting up dependencies that require work on startup. This
is useful for the DataStore that wants to read a bunch from the database
before initiliazing.
This commit is contained in:
Erik Johnston 2016-01-26 15:51:06 +00:00
parent 9959d9ece8
commit 87f9477b10
9 changed files with 121 additions and 116 deletions

View file

@ -254,6 +254,17 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
def get_db_conn(self):
db_conn = self.database_engine.module.connect(
**{
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
self.database_engine.on_new_connection(db_conn)
return db_conn
def quit_with_error(error_string):
message_lines = error_string.split("\n")
@ -390,13 +401,7 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
db_conn = database_engine.module.connect(
**{
k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
db_conn = hs.get_db_conn()
database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
@ -411,13 +416,17 @@ def setup(config_options):
logger.info("Database prepared in %s.", config.database_config['name'])
hs.setup()
hs.start_listening()
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
def start():
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
reactor.callWhenRunning(start)
return hs

View file

@ -21,6 +21,7 @@
# Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@ -28,7 +29,7 @@ from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.storage import get_datastore
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.streams.events import EventSources
@ -40,6 +41,11 @@ from synapse.api.filtering import Filtering
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
import logging
logger = logging.getLogger(__name__)
class HomeServer(object):
"""A basic homeserver object without lazy component builders.
@ -102,10 +108,19 @@ class HomeServer(object):
self.hostname = hostname
self._building = {}
self.clock = Clock()
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
def setup(self):
logger.info("Setting up.")
self.datastore = get_datastore(self)
logger.info("Finished setting up.")
def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type.
return request.getClientIP()
@ -116,15 +131,9 @@ class HomeServer(object):
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
def build_clock(self):
return Clock()
def build_replication_layer(self):
return initialize_http_replication(self)
def build_datastore(self):
return DataStore(self)
def build_handlers(self):
return Handlers(self)
@ -135,10 +144,9 @@ class HomeServer(object):
return Auth(self)
def build_http_client_context_factory(self):
config = self.get_config()
return (
InsecureInterceptableContextFactory()
if config.use_insecure_ssl_client_just_for_testing_do_not_use
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
@ -157,15 +165,9 @@ class HomeServer(object):
def build_state_handler(self):
return StateHandler(self)
def build_distributor(self):
return Distributor()
def build_event_sources(self):
return EventSources(self)
def build_ratelimiter(self):
return Ratelimiter()
def build_keyring(self):
return Keyring(self)

View file

@ -46,6 +46,9 @@ from .tags import TagsStore
from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
import logging
@ -58,6 +61,22 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120*1000
def get_datastore(hs):
logger.info("getting called!")
conn = hs.get_db_conn()
try:
cur = conn.cursor()
cur.execute("SELECT MIN(stream_ordering) FROM events",)
rows = cur.fetchall()
min_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
min_token = min(min_token, -1)
return DataStore(conn, hs, min_token)
finally:
conn.close()
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
@ -79,18 +98,36 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore
):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
def __init__(self, db_conn, hs, min_stream_token):
self.hs = hs
self.min_token_deferred = self._get_min_token()
self.min_token = None
self.min_stream_token = min_stream_token
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering"
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
super(DataStore, self).__init__(hs)
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())

View file

@ -15,13 +15,11 @@
import logging
from synapse.api.errors import StoreError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
@ -175,16 +173,6 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@ -345,7 +333,8 @@ class SQLBaseStore(object):
defer.returnValue(result)
def cursor_to_dict(self, cursor):
@staticmethod
def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
Args:
@ -402,8 +391,8 @@ class SQLBaseStore(object):
if not or_ignore:
raise
@log_function
def _simple_insert_txn(self, txn, table, values):
@staticmethod
def _simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@ -414,7 +403,8 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
def _simple_insert_many_txn(self, txn, table, values):
@staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values:
return
@ -537,9 +527,10 @@ class SQLBaseStore(object):
table, keyvalues, retcol, allow_none=allow_none,
)
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
@classmethod
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False):
ret = self._simple_select_onecol_txn(
ret = cls._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
@ -554,7 +545,8 @@ class SQLBaseStore(object):
else:
raise StoreError(404, "No row found")
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
@ -603,7 +595,8 @@ class SQLBaseStore(object):
table, keyvalues, retcols
)
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
@classmethod
def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -627,7 +620,7 @@ class SQLBaseStore(object):
)
txn.execute(sql)
return self.cursor_to_dict(txn)
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols,
@ -662,7 +655,8 @@ class SQLBaseStore(object):
defer.returnValue(results)
def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
@classmethod
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -699,7 +693,7 @@ class SQLBaseStore(object):
)
txn.execute(sql, values)
return self.cursor_to_dict(txn)
return cls.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
@ -726,7 +720,8 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
@staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
@ -743,7 +738,8 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
@staticmethod
def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
@ -784,7 +780,8 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
def _simple_delete_txn(self, txn, table, keyvalues):
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)

View file

@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
return
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
start = self.min_token - 1
self.min_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_token, -1)
start = self.min_stream_token - 1
self.min_stream_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
self.min_token -= 1
stream_ordering = self.min_token
self.min_stream_token -= 1
stream_ordering = self.min_stream_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)

View file

@ -31,7 +31,9 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = _RoomStreamChangeCache()
self._receipts_stream_cache = _RoomStreamChangeCache(
self._receipts_id_gen.get_max_token(None)
)
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@ -377,11 +379,11 @@ class _RoomStreamChangeCache(object):
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
def __init__(self, size_of_cache=10000):
def __init__(self, current_key, size_of_cache=10000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
self._earliest_key = None
self._earliest_key = current_key
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache

View file

@ -444,19 +444,6 @@ class StreamStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0] if rows else 0
@defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
)
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
self.min_token = min(self.min_token, -1)
logger.debug("min_token is: %s", self.min_token)
defer.returnValue(self.min_token)
@staticmethod
def _set_before_and_after(events, rows):
for event, row in zip(events, rows):

View file

@ -16,7 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
@ -25,12 +24,6 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
self._account_data_id_gen = StreamIdGenerator(
"account_data_max_stream_id", "stream_id"
)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream

View file

@ -72,28 +72,24 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
def __init__(self, table, column):
def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock()
self._current_max = None
cur = db_conn.cursor()
self._current_max = self._get_or_compute_current_max(cur)
cur.close()
self._unfinished_ids = deque()
@defer.inlineCallbacks
def get_next(self, store):
"""
Usage:
with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
self._current_max += 1
next_id = self._current_max
@ -108,21 +104,14 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
return manager()
@defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
@ -139,24 +128,17 @@ class StreamIdGenerator(object):
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
return manager()
@defer.inlineCallbacks
def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
if self._unfinished_ids:
defer.returnValue(self._unfinished_ids[0] - 1)
return self._unfinished_ids[0] - 1
defer.returnValue(self._current_max)
return self._current_max
def _get_or_compute_current_max(self, txn):
with self._lock: