From d537be1ebd0e7ce4c84118efa400932cc6432aa9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 13:40:02 +0000 Subject: [PATCH] Pass Database into the data store --- synapse/server.py | 3 +- synapse/storage/_base.py | 2 +- synapse/storage/background_updates.py | 2 +- synapse/storage/data_stores/__init__.py | 7 +++-- synapse/storage/database.py | 38 +++++++++++-------------- 5 files changed, 24 insertions(+), 28 deletions(-) diff --git a/synapse/server.py b/synapse/server.py index be9af7f98..2db3dab22 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -238,8 +238,7 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - datastore = self.DATASTORE_CLASS(conn, self) - self.datastores = DataStores(datastore, conn, self) + self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) conn.commit() self.start_time = int(self.get_clock().time()) logger.info("Finished setting up.") diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f9e7f9a71..b7637b5dc 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -41,7 +41,7 @@ class SQLBaseStore(object): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine - self.db = Database(hs) # In future this will be passed in + self.db = database self.rand = random.SystemRandom() def _invalidate_state_caches(self, room_id, members_changed): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index a9a13a265..4f97fd5ab 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -379,7 +379,7 @@ class BackgroundUpdater(object): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.db.database_engine, engines.PostgresEngine): + if isinstance(self.db.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cb184a98c..79ecc6273 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.database import Database + class DataStores(object): """The various data stores. @@ -20,7 +22,8 @@ class DataStores(object): These are low level interfaces to physical databases. """ - def __init__(self, main_store, db_conn, hs): + def __init__(self, main_store_class, db_conn, hs): # Note we pass in the main store here as workers use a different main # store. - self.main = main_store + database = Database(hs) + self.main = main_store_class(database, db_conn, hs) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6843b7e7f..ec19ae1d9 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -234,7 +234,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.database_engine = hs.database_engine + self.engine = hs.database_engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -242,10 +242,10 @@ class Database(object): # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. - if isinstance(self.database_engine, Sqlite3Engine): + if isinstance(self.engine, Sqlite3Engine): self._unsafe_to_upsert_tables.add("user_directory_search") - if self.database_engine.can_native_upsert: + if self.engine.can_native_upsert: # Check ASAP (and then later, every 1s) to see if we have finished # background updates of tables that aren't safe to update. self._clock.call_later( @@ -331,7 +331,7 @@ class Database(object): cursor = LoggingTransaction( conn.cursor(), name, - self.database_engine, + self.engine, after_callbacks, exception_callbacks, ) @@ -339,7 +339,7 @@ class Database(object): r = func(cursor, *args, **kwargs) conn.commit() return r - except self.database_engine.module.OperationalError as e: + except self.engine.module.OperationalError as e: # This can happen if the database disappears mid # transaction. logger.warning( @@ -353,20 +353,20 @@ class Database(object): i += 1 try: conn.rollback() - except self.database_engine.module.Error as e1: + except self.engine.module.Error as e1: logger.warning( "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) ) continue raise - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): + except self.engine.module.DatabaseError as e: + if self.engine.is_deadlock(e): logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) if i < N: i += 1 try: conn.rollback() - except self.database_engine.module.Error as e1: + except self.engine.module.Error as e1: logger.warning( "[TXN EROLL] {%s} %s", name, @@ -494,7 +494,7 @@ class Database(object): sql_scheduling_timer.observe(sched_duration_sec) context.add_database_scheduled(sched_duration_sec) - if self.database_engine.is_connection_closed(conn): + if self.engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") conn.reconnect() @@ -561,7 +561,7 @@ class Database(object): """ try: yield self.runInteraction(desc, self.simple_insert_txn, table, values) - except self.database_engine.module.IntegrityError: + except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. if not or_ignore: @@ -660,7 +660,7 @@ class Database(object): lock=lock, ) return result - except self.database_engine.module.IntegrityError as e: + except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: # don't retry forever, because things other than races @@ -692,10 +692,7 @@ class Database(object): upserts return True if a new entry was created, False if an existing one was updated. """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) @@ -726,7 +723,7 @@ class Database(object): """ # We need to lock the table :(, unless we're *really* careful if lock: - self.database_engine.lock_table(txn, table) + self.engine.lock_table(txn, table) def _getwhere(key): # If the value we're passing in is None (aka NULL), we need to use @@ -828,10 +825,7 @@ class Database(object): Returns: None """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values ) @@ -1301,7 +1295,7 @@ class Database(object): "limit": limit, } - sql = self.database_engine.convert_param_style(sql) + sql = self.engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (int(max_value),))