0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-26 14:38:18 +02:00

Add logging on startup/shutdown (#8448)

This is so we can tell what is going on when things are taking a while to start up.

The main change here is to ensure that transactions that are created during startup get correctly logged like normal transactions.
This commit is contained in:
Erik Johnston 2020-10-02 15:20:45 +01:00 committed by GitHub
parent ec10bdd32b
commit e3debf9682
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 152 additions and 113 deletions

1
changelog.d/8448.misc Normal file
View file

@ -0,0 +1 @@
Add SQL logging on queries that happen during startup.

View file

@ -489,7 +489,7 @@ class Porter(object):
hs = MockHomeserver(self.hs_config) hs = MockHomeserver(self.hs_config)
with make_conn(db_config, engine) as db_conn: with make_conn(db_config, engine, "portdb") as db_conn:
engine.check_database( engine.check_database(
db_conn, allow_outdated_version=allow_outdated_version db_conn, allow_outdated_version=allow_outdated_version
) )

View file

@ -272,6 +272,11 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
hs.get_datastore().db_pool.start_profiling() hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start() hs.get_pusherpool().start()
# Log when we start the shut down process.
hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", logger.info, "Shutting down..."
)
setup_sentry(hs) setup_sentry(hs)
setup_sdnotify(hs) setup_sdnotify(hs)

View file

@ -32,6 +32,7 @@ from typing import (
overload, overload,
) )
import attr
from prometheus_client import Histogram from prometheus_client import Histogram
from typing_extensions import Literal from typing_extensions import Literal
@ -90,13 +91,17 @@ def make_pool(
return adbapi.ConnectionPool( return adbapi.ConnectionPool(
db_config.config["name"], db_config.config["name"],
cp_reactor=reactor, cp_reactor=reactor,
cp_openfun=engine.on_new_connection, cp_openfun=lambda conn: engine.on_new_connection(
LoggingDatabaseConnection(conn, engine, "on_new_connection")
),
**db_config.config.get("args", {}) **db_config.config.get("args", {})
) )
def make_conn( def make_conn(
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine db_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
default_txn_name: str,
) -> Connection: ) -> Connection:
"""Make a new connection to the database and return it. """Make a new connection to the database and return it.
@ -109,11 +114,60 @@ def make_conn(
for k, v in db_config.config.get("args", {}).items() for k, v in db_config.config.get("args", {}).items()
if not k.startswith("cp_") if not k.startswith("cp_")
} }
db_conn = engine.module.connect(**db_params) native_db_conn = engine.module.connect(**db_params)
db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
engine.on_new_connection(db_conn) engine.on_new_connection(db_conn)
return db_conn return db_conn
@attr.s(slots=True)
class LoggingDatabaseConnection:
"""A wrapper around a database connection that returns `LoggingTransaction`
as its cursor class.
This is mainly used on startup to ensure that queries get logged correctly
"""
conn = attr.ib(type=Connection)
engine = attr.ib(type=BaseDatabaseEngine)
default_txn_name = attr.ib(type=str)
def cursor(
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
) -> "LoggingTransaction":
if not txn_name:
txn_name = self.default_txn_name
return LoggingTransaction(
self.conn.cursor(),
name=txn_name,
database_engine=self.engine,
after_callbacks=after_callbacks,
exception_callbacks=exception_callbacks,
)
def close(self) -> None:
self.conn.close()
def commit(self) -> None:
self.conn.commit()
def rollback(self, *args, **kwargs) -> None:
self.conn.rollback(*args, **kwargs)
def __enter__(self) -> "Connection":
self.conn.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback) -> bool:
return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class.
def __getattr__(self, name):
return getattr(self.conn, name)
# The type of entry which goes on our after_callbacks and exception_callbacks lists. # The type of entry which goes on our after_callbacks and exception_callbacks lists.
# #
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@ -247,6 +301,12 @@ class LoggingTransaction:
def close(self) -> None: def close(self) -> None:
self.txn.close() self.txn.close()
def __enter__(self) -> "LoggingTransaction":
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
class PerformanceCounters: class PerformanceCounters:
def __init__(self): def __init__(self):
@ -395,7 +455,7 @@ class DatabasePool:
def new_transaction( def new_transaction(
self, self,
conn: Connection, conn: LoggingDatabaseConnection,
desc: str, desc: str,
after_callbacks: List[_CallbackListEntry], after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry],
@ -418,12 +478,10 @@ class DatabasePool:
i = 0 i = 0
N = 5 N = 5
while True: while True:
cursor = LoggingTransaction( cursor = conn.cursor(
conn.cursor(), txn_name=name,
name, after_callbacks=after_callbacks,
self.engine, exception_callbacks=exception_callbacks,
after_callbacks,
exception_callbacks,
) )
try: try:
r = func(cursor, *args, **kwargs) r = func(cursor, *args, **kwargs)
@ -584,7 +642,10 @@ class DatabasePool:
logger.debug("Reconnecting closed database connection") logger.debug("Reconnecting closed database connection")
conn.reconnect() conn.reconnect()
return func(conn, *args, **kwargs) db_conn = LoggingDatabaseConnection(
conn, self.engine, "runWithConnection"
)
return func(db_conn, *args, **kwargs)
return await make_deferred_yieldable( return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs) self._db_pool.runWithConnection(inner_func, *args, **kwargs)
@ -1621,7 +1682,7 @@ class DatabasePool:
def get_cache_dict( def get_cache_dict(
self, self,
db_conn: Connection, db_conn: LoggingDatabaseConnection,
table: str, table: str,
entity_column: str, entity_column: str,
stream_column: str, stream_column: str,
@ -1642,9 +1703,7 @@ class DatabasePool:
"limit": limit, "limit": limit,
} }
sql = self.engine.convert_param_style(sql) txn = db_conn.cursor(txn_name="get_cache_dict")
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
cache = {row[0]: int(row[1]) for row in txn} cache = {row[0]: int(row[1]) for row in txn}

View file

@ -46,7 +46,7 @@ class Databases:
db_name = database_config.name db_name = database_config.name
engine = create_engine(database_config.config) engine = create_engine(database_config.config)
with make_conn(database_config, engine) as db_conn: with make_conn(database_config, engine, "startup") as db_conn:
logger.info("[database config %r]: Checking database server", db_name) logger.info("[database config %r]: Checking database server", db_name)
engine.check_database(db_conn) engine.check_database(db_conn)

View file

@ -284,7 +284,6 @@ class DataStore(
" last_user_sync_ts, status_msg, currently_active FROM presence_stream" " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?" " WHERE state != ?"
) )
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,)) txn.execute(sql, (PresenceState.OFFLINE,))

View file

@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union
import attr import attr
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -74,11 +74,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.stream_ordering_month_ago = None self.stream_ordering_month_ago = None
self.stream_ordering_day_ago = None self.stream_ordering_day_ago = None
cur = LoggingTransaction( cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
)
self._find_stream_orderings_for_times_txn(cur) self._find_stream_orderings_for_times_txn(cur)
cur.close() cur.close()

View file

@ -214,7 +214,6 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
self._mau_stats_only = hs.config.mau_stats_only self._mau_stats_only = hs.config.mau_stats_only
# Do not add more reserved users than the total allowable number # Do not add more reserved users than the total allowable number
# cur = LoggingTransaction(
self.db_pool.new_transaction( self.db_pool.new_transaction(
db_conn, db_conn,
"initialise_mau_threepids", "initialise_mau_threepids",

View file

@ -21,12 +21,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import ( from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
LoggingTransaction,
SQLBaseStore,
db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine from synapse.storage.engines import Sqlite3Engine
@ -60,10 +55,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# background update still running? # background update still running?
self._current_state_events_membership_up_to_date = False self._current_state_events_membership_up_to_date = False
txn = LoggingTransaction( txn = db_conn.cursor(
db_conn.cursor(), txn_name="_check_safe_current_state_events_membership_updated"
name="_check_safe_current_state_events_membership_updated",
database_engine=self.database_engine,
) )
self._check_safe_current_state_events_membership_updated_txn(txn) self._check_safe_current_state_events_membership_updated_txn(txn)
txn.close() txn.close()

View file

@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row[8] = bytes(row[8]).decode("utf-8") row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8")
cur.execute( cur.execute(
database_engine.convert_param_style( """
""" INSERT into pushers2 (
INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind,
id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name,
app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success,
pushkey, ts, lang, data, last_token, last_success, failing_since
failing_since ) values (%s)
) values (%s)""" """
% (",".join(["?" for _ in range(len(row))])) % (",".join(["?" for _ in range(len(row))])),
),
row, row,
) )
count += 1 count += 1

View file

@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)" " VALUES (?, ?)"
) )
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json)) cur.execute(sql, ("event_search", progress_json))

View file

@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)" " VALUES (?, ?)"
) )
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json)) cur.execute(sql, ("event_origin_server_ts", progress_json))

View file

@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks: for chunk in user_chunks:
cur.execute( cur.execute(
database_engine.convert_param_style( "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
"UPDATE users SET appservice_id = ? WHERE name IN (%s)" % (",".join("?" for _ in chunk),),
% (",".join("?" for _ in chunk),)
),
[as_id] + chunk, [as_id] + chunk,
) )

View file

@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row = list(row) row = list(row)
row[12] = token_to_stream_ordering(row[12]) row[12] = token_to_stream_ordering(row[12])
cur.execute( cur.execute(
database_engine.convert_param_style( """
""" INSERT into pushers2 (
INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind,
id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name,
app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_stream_ordering, last_success,
pushkey, ts, lang, data, last_stream_ordering, last_success, failing_since
failing_since ) values (%s)
) values (%s)""" """
% (",".join(["?" for _ in range(len(row))])) % (",".join(["?" for _ in range(len(row))])),
),
row, row,
) )
count += 1 count += 1

View file

@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)" " VALUES (?, ?)"
) )
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search_order", progress_json)) cur.execute(sql, ("event_search_order", progress_json))

View file

@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)" " VALUES (?, ?)"
) )
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_fields_sender_url", progress_json)) cur.execute(sql, ("event_fields_sender_url", progress_json))

View file

@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute( cur.execute(
database_engine.convert_param_style( "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
"UPDATE remote_media_cache SET last_access_ts = ?"
),
(int(time.time() * 1000),),
) )

View file

@ -1,6 +1,8 @@
import logging import logging
from io import StringIO
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import execute_statements_from_stream
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
select_clause, select_clause,
) )
if isinstance(database_engine, PostgresEngine): execute_statements_from_stream(cur, StringIO(sql))
cur.execute(sql)
else:
cur.executescript(sql)

View file

@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
INNER JOIN room_memberships AS r USING (event_id) INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' AND state_key LIKE ? WHERE type = 'm.room.member' AND state_key LIKE ?
""" """
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("%:" + config.server_name,)) cur.execute(sql, ("%:" + config.server_name,))
cur.execute( cur.execute(

View file

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import imp import imp
import logging import logging
import os import os
@ -24,9 +23,10 @@ from typing import Optional, TextIO
import attr import attr
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Cursor
from synapse.types import Collection from synapse.types import Collection
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
def prepare_database( def prepare_database(
db_conn: Connection, db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine, database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig], config: Optional[HomeServerConfig],
databases: Collection[str] = ["main", "state"], databases: Collection[str] = ["main", "state"],
@ -89,7 +89,7 @@ def prepare_database(
""" """
try: try:
cur = db_conn.cursor() cur = db_conn.cursor(txn_name="prepare_database")
# sqlite does not automatically start transactions for DDL / SELECT statements, # sqlite does not automatically start transactions for DDL / SELECT statements,
# so we start one before running anything. This ensures that any upgrades # so we start one before running anything. This ensures that any upgrades
@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases):
executescript(cur, entry.absolute_path) executescript(cur, entry.absolute_path)
cur.execute( cur.execute(
database_engine.convert_param_style( "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
"INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(max_current_ver, False), (max_current_ver, False),
) )
@ -486,17 +484,13 @@ def _upgrade_existing_database(
# Mark as done. # Mark as done.
cur.execute( cur.execute(
database_engine.convert_param_style( "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
"INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
),
(v, relative_path), (v, relative_path),
) )
cur.execute("DELETE FROM schema_version") cur.execute("DELETE FROM schema_version")
cur.execute( cur.execute(
database_engine.convert_param_style( "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
"INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(v, True), (v, True),
) )
@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
schemas to be applied schemas to be applied
""" """
cur.execute( cur.execute(
database_engine.convert_param_style( "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
),
(modname,),
) )
applied_deltas = {d for d, in cur} applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams: for (name, stream) in names_and_streams:
@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done. # Mark as done.
cur.execute( cur.execute(
database_engine.convert_param_style( "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
"INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
),
(modname, name), (modname, name),
) )
@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine):
if current_version: if current_version:
txn.execute( txn.execute(
database_engine.convert_param_style( "SELECT file FROM applied_schema_deltas WHERE version >= ?",
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
),
(current_version,), (current_version,),
) )
applied_deltas = [d for d, in txn] applied_deltas = [d for d, in txn]

View file

@ -61,3 +61,9 @@ class Connection(Protocol):
def rollback(self, *args, **kwargs) -> None: def rollback(self, *args, **kwargs) -> None:
... ...
def __enter__(self) -> "Connection":
...
def __exit__(self, exc_type, exc_value, traceback) -> bool:
...

View file

@ -54,7 +54,7 @@ def _load_current_id(db_conn, table, column, step=1):
""" """
# debug logging for https://github.com/matrix-org/synapse/issues/7968 # debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column) logger.info("initialising stream generator for %s(%s)", table, column)
cur = db_conn.cursor() cur = db_conn.cursor(txn_name="_load_current_id")
if step == 1: if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else: else:
@ -269,7 +269,7 @@ class MultiWriterIdGenerator:
def _load_current_ids( def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str self, db_conn, table: str, instance_column: str, id_column: str
): ):
cur = db_conn.cursor() cur = db_conn.cursor(txn_name="_load_current_ids")
# Load the current positions of all writers for the stream. # Load the current positions of all writers for the stream.
if self._writers: if self._writers:
@ -283,15 +283,12 @@ class MultiWriterIdGenerator:
stream_name = ? stream_name = ?
AND instance_name != ALL(?) AND instance_name != ALL(?)
""" """
sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (self._stream_name, self._writers)) cur.execute(sql, (self._stream_name, self._writers))
sql = """ sql = """
SELECT instance_name, stream_id FROM stream_positions SELECT instance_name, stream_id FROM stream_positions
WHERE stream_name = ? WHERE stream_name = ?
""" """
sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (self._stream_name,)) cur.execute(sql, (self._stream_name,))
self._current_positions = { self._current_positions = {
@ -340,7 +337,6 @@ class MultiWriterIdGenerator:
"instance": instance_column, "instance": instance_column,
"cmp": "<=" if self._positive else ">=", "cmp": "<=" if self._positive else ">=",
} }
sql = self._db.engine.convert_param_style(sql)
cur.execute(sql, (min_stream_id * self._return_factor,)) cur.execute(sql, (min_stream_id * self._return_factor,))
self._persisted_upto_position = min_stream_id self._persisted_upto_position = min_stream_id

View file

@ -17,6 +17,7 @@ import logging
import threading import threading
from typing import Callable, List, Optional from typing import Callable, List, Optional
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import ( from synapse.storage.engines import (
BaseDatabaseEngine, BaseDatabaseEngine,
IncorrectDatabaseSetup, IncorrectDatabaseSetup,
@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def check_consistency( def check_consistency(
self, db_conn: Connection, table: str, id_column: str, positive: bool = True self,
db_conn: LoggingDatabaseConnection,
table: str,
id_column: str,
positive: bool = True,
): ):
"""Should be called during start up to test that the current value of """Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table. the sequence is greater than or equal to the maximum ID in the table.
@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator):
return [i for (i,) in txn] return [i for (i,) in txn]
def check_consistency( def check_consistency(
self, db_conn: Connection, table: str, id_column: str, positive: bool = True self,
db_conn: LoggingDatabaseConnection,
table: str,
id_column: str,
positive: bool = True,
): ):
txn = db_conn.cursor() txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table. # First we get the current max ID from the table.
table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {

View file

@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
# must be done after inserts # must be done after inserts
database = hs.get_datastores().databases[0] database = hs.get_datastores().databases[0]
self.store = ApplicationServiceStore( self.store = ApplicationServiceStore(
database, make_conn(database._database_config, database.engine), hs database, make_conn(database._database_config, database.engine, "test"), hs
) )
def tearDown(self): def tearDown(self):
@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
db_config = hs.config.get_single_database() db_config = hs.config.get_single_database()
self.store = TestTransactionStore( self.store = TestTransactionStore(
database, make_conn(db_config, self.engine), hs database, make_conn(db_config, self.engine, "test"), hs
) )
def _add_service(self, url, as_token, id): def _add_service(self, url, as_token, id):
@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
database = hs.get_datastores().databases[0] database = hs.get_datastores().databases[0]
ApplicationServiceStore( ApplicationServiceStore(
database, make_conn(database._database_config, database.engine), hs database, make_conn(database._database_config, database.engine, "test"), hs
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0] database = hs.get_datastores().databases[0]
ApplicationServiceStore( ApplicationServiceStore(
database, make_conn(database._database_config, database.engine), hs database,
make_conn(database._database_config, database.engine, "test"),
hs,
) )
e = cm.exception e = cm.exception
@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0] database = hs.get_datastores().databases[0]
ApplicationServiceStore( ApplicationServiceStore(
database, make_conn(database._database_config, database.engine), hs database,
make_conn(database._database_config, database.engine, "test"),
hs,
) )
e = cm.exception e = cm.exception

View file

@ -38,6 +38,7 @@ from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
@ -88,6 +89,7 @@ def setupdb():
host=POSTGRES_HOST, host=POSTGRES_HOST,
password=POSTGRES_PASSWORD, password=POSTGRES_PASSWORD,
) )
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None) prepare_database(db_conn, db_engine, None)
db_conn.close() db_conn.close()