Refactor to ensure we call check_consistency (#9470)

The idea here is to stop people forgetting to call `check_consistency`. Folks can still just pass in `None` to the new args in `build_sequence_generator`, but hopefully they won't.
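For illustration, a minimal sketch of the shape this enforces (the `things`/`thing_id_seq` names below are hypothetical, not from this commit): `table` and `id_column` are now mandatory arguments, so skipping the startup consistency check takes an explicit `None` rather than a silent omission.

    # Hypothetical caller. The consistency check now runs inside
    # build_sequence_generator itself, at construction time.
    self._thing_id_seq = build_sequence_generator(
        db_conn,
        database.engine,
        get_max_thing_id_txn,  # seeds the counter on SQLite
        "thing_id_seq",        # postgres sequence to draw IDs from
        table="things",        # check_consistency() compares the sequence...
        id_column="id",        # ...against max(id) of this table at startup
    )

    # Opting out is still possible, but has to be spelled out:
    self._other_seq = build_sequence_generator(
        db_conn, database.engine, get_first_cb, "other_seq",
        table=None, id_column=None,
    )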
Erik Johnston 2021-02-24 10:13:53 +00:00 committed by GitHub
parent 713145d3de
commit 0b5c967813
8 changed files with 72 additions and 28 deletions

changelog.d/9470.bugfix (new file)

@@ -0,0 +1 @@
+Fix missing startup checks for the consistency of certain PostgreSQL sequences.

synapse/storage/database.py

@@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -381,7 +380,10 @@ class DatabasePool:
     _TXN_ID = 0
 
     def __init__(
-        self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+        self,
+        hs,
+        database_config: DatabaseConnectionConfig,
+        engine: BaseDatabaseEngine,
     ):
         self.hs = hs
         self._clock = hs.get_clock()
@@ -420,16 +422,6 @@ class DatabasePool:
             self._check_safe_to_upsert,
         )
 
-        # We define this sequence here so that it can be referenced from both
-        # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
-            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
-
-        self.event_chain_id_gen = build_sequence_generator(
-            engine, get_chain_id_txn, "event_auth_chain_id"
-        )
-
     def is_running(self) -> bool:
         """Is the database pool currently running"""
         return self._db_pool.running

synapse/storage/databases/main/events.py

@@ -44,6 +44,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -114,12 +115,6 @@ class PersistEventsStore:
         )  # type: MultiWriterIdGenerator
         self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
 
-        # The consistency of this cannot be checked when the ID generator is
-        # created since the database might not yet be up-to-date.
-        self.db_pool.event_chain_id_gen.check_consistency(
-            db_conn, "event_auth_chains", "chain_id"  # type: ignore
-        )
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
@@ -485,6 +480,7 @@ class PersistEventsStore:
         self._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.store.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -495,6 +491,7 @@ class PersistEventsStore:
         cls,
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -641,6 +638,7 @@ class PersistEventsStore:
         new_chain_tuples = cls._allocate_chain_ids(
             txn,
             db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -779,6 +777,7 @@ class PersistEventsStore:
     def _allocate_chain_ids(
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -891,7 +890,7 @@ class PersistEventsStore:
             chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
 
         # Generate new chain IDs for all unallocated chain IDs.
-        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+        newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
             txn, len(unallocated_chain_ids)
         )

synapse/storage/databases/main/events_bg_updates.py

@@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         PersistEventsStore._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,

synapse/storage/databases/main/events_worker.py

@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            db_conn,
+            database.engine,
+            get_chain_id_txn,
+            "event_auth_chain_id",
+            table="event_auth_chains",
+            id_column="chain_id",
+        )
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)

synapse/storage/databases/main/registration.py

@@ -23,7 +23,7 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
+            db_conn,
             database.engine,
             find_max_generated_user_id_localpart,
             "user_id_seq",
+            table=None,
+            id_column=None,
         )
 
         self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._clock = hs.get_clock()

synapse/storage/databases/state/store.py

@@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             return txn.fetchone()[0]
 
         self._state_group_seq_gen = build_sequence_generator(
-            self.database_engine, get_max_state_group_txn, "state_group_id_seq"
-        )
-        self._state_group_seq_gen.check_consistency(
-            db_conn, table="state_groups", id_column="id"
+            db_conn,
+            self.database_engine,
+            get_max_state_group_txn,
+            "state_group_id_seq",
+            table="state_groups",
+            id_column="id",
         )
 
     @cached(max_entries=10000, iterable=True)

synapse/storage/util/sequence.py

@@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator):
 
 def build_sequence_generator(
+    db_conn: "LoggingDatabaseConnection",
     database_engine: BaseDatabaseEngine,
     get_first_callback: GetFirstCallbackType,
     sequence_name: str,
+    table: Optional[str],
+    id_column: Optional[str],
+    stream_name: Optional[str] = None,
+    positive: bool = True,
 ) -> SequenceGenerator:
     """Get the best impl of SequenceGenerator available
 
@@ -265,8 +270,23 @@ def build_sequence_generator(
         get_first_callback: a callback which gets the next sequence ID. Used if
             we're on sqlite.
         sequence_name: the name of a postgres sequence to use.
+        table, id_column, stream_name, positive: If set then `check_consistency`
+            is called on the created sequence. See the docstring of
+            `check_consistency` for details.
     """
     if isinstance(database_engine, PostgresEngine):
-        return PostgresSequenceGenerator(sequence_name)
+        seq = PostgresSequenceGenerator(sequence_name)  # type: SequenceGenerator
     else:
-        return LocalSequenceGenerator(get_first_callback)
+        seq = LocalSequenceGenerator(get_first_callback)
+
+    if table:
+        assert id_column
+        seq.check_consistency(
+            db_conn=db_conn,
+            table=table,
+            id_column=id_column,
+            stream_name=stream_name,
+            positive=positive,
+        )
+
+    return seq
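
As a usage sketch (assuming the signature above, with `txn` and `database_engine` standing in for a real transaction and engine; the names mirror the state-store change earlier in this commit): on PostgreSQL the returned generator draws from the named sequence and is validated against the table immediately, while on SQLite the callback seeds a purely local counter.

    seq = build_sequence_generator(
        db_conn,
        database_engine,
        get_max_state_group_txn,
        "state_group_id_seq",
        table="state_groups",   # triggers check_consistency() at build time
        id_column="id",
    )
    next_id = seq.get_next_id_txn(txn)    # allocate one ID in a transaction
    many = seq.get_next_mult_txn(txn, 5)  # or several at once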