Refactor to ensure we call check_consistency (#9470)
The idea here is to stop people forgetting to call `check_consistency`. Folks can still just pass in `None` to the new args in `build_sequence_generator`, but hopefully they won't.
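Concretely, the consistency check now happens inside `build_sequence_generator` whenever `table` and `id_column` are given, so the safe behaviour is the default. A minimal sketch of the two calling styles, assuming a `db_conn` and `database.engine` are in scope (the sequence and table names below are illustrative, not from this commit):

    # The common case: the startup consistency check runs automatically.
    seq = build_sequence_generator(
        db_conn,              # connection the startup check runs against
        database.engine,      # selects the postgres or sqlite implementation
        get_first_callback,   # seeds the sequence when running on sqlite
        "some_id_seq",        # postgres sequence name (illustrative)
        table="some_table",   # with id_column set, check_consistency is called
        id_column="id",
    )

    # Opting out is still possible, but is now an explicit decision
    # rather than an easy omission.
    unchecked = build_sequence_generator(
        db_conn, database.engine, get_first_callback, "some_id_seq",
        table=None, id_column=None,
    )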
parent 713145d3de
commit 0b5c967813
8 changed files with 72 additions and 28 deletions
changelog.d/9470.bugfix (new file)
@@ -0,0 +1 @@
+Fix missing startup checks for the consistency of certain PostgreSQL sequences.
synapse/storage/database.py
@@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -381,7 +380,10 @@ class DatabasePool:
     _TXN_ID = 0
 
     def __init__(
-        self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+        self,
+        hs,
+        database_config: DatabaseConnectionConfig,
+        engine: BaseDatabaseEngine,
     ):
         self.hs = hs
         self._clock = hs.get_clock()
@@ -420,16 +422,6 @@ class DatabasePool:
             self._check_safe_to_upsert,
         )
 
-        # We define this sequence here so that it can be referenced from both
-        # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
-            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
-
-        self.event_chain_id_gen = build_sequence_generator(
-            engine, get_chain_id_txn, "event_auth_chain_id"
-        )
-
     def is_running(self) -> bool:
         """Is the database pool currently running"""
         return self._db_pool.running
synapse/storage/databases/main/events.py
@@ -44,6 +44,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -114,12 +115,6 @@ class PersistEventsStore:
         )  # type: MultiWriterIdGenerator
         self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
 
-        # The consistency of this cannot be checked when the ID generator is
-        # created since the database might not yet be up-to-date.
-        self.db_pool.event_chain_id_gen.check_consistency(
-            db_conn, "event_auth_chains", "chain_id"  # type: ignore
-        )
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
@@ -485,6 +480,7 @@ class PersistEventsStore:
         self._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.store.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -495,6 +491,7 @@ class PersistEventsStore:
         cls,
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -641,6 +638,7 @@ class PersistEventsStore:
         new_chain_tuples = cls._allocate_chain_ids(
             txn,
             db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -779,6 +777,7 @@ class PersistEventsStore:
     def _allocate_chain_ids(
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -891,7 +890,7 @@ class PersistEventsStore:
         chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
 
         # Generate new chain IDs for all unallocated chain IDs.
-        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+        newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
             txn, len(unallocated_chain_ids)
         )
 
synapse/storage/databases/main/events_bg_updates.py
@@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         PersistEventsStore._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
synapse/storage/databases/main/events_worker.py
@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            db_conn,
+            database.engine,
+            get_chain_id_txn,
+            "event_auth_chain_id",
+            table="event_auth_chains",
+            id_column="chain_id",
+        )
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
 
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
+            db_conn,
             database.engine,
             find_max_generated_user_id_localpart,
             "user_id_seq",
+            table=None,
+            id_column=None,
         )
 
         self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._clock = hs.get_clock()
synapse/storage/databases/state/store.py
@@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             return txn.fetchone()[0]
 
         self._state_group_seq_gen = build_sequence_generator(
-            self.database_engine, get_max_state_group_txn, "state_group_id_seq"
-        )
-        self._state_group_seq_gen.check_consistency(
-            db_conn, table="state_groups", id_column="id"
+            db_conn,
+            self.database_engine,
+            get_max_state_group_txn,
+            "state_group_id_seq",
+            table="state_groups",
+            id_column="id",
         )
 
     @cached(max_entries=10000, iterable=True)
synapse/storage/util/sequence.py
@@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator):
 
 
 def build_sequence_generator(
+    db_conn: "LoggingDatabaseConnection",
     database_engine: BaseDatabaseEngine,
     get_first_callback: GetFirstCallbackType,
     sequence_name: str,
+    table: Optional[str],
+    id_column: Optional[str],
+    stream_name: Optional[str] = None,
+    positive: bool = True,
 ) -> SequenceGenerator:
     """Get the best impl of SequenceGenerator available
 
@@ -265,8 +270,23 @@ def build_sequence_generator(
         get_first_callback: a callback which gets the next sequence ID. Used if
             we're on sqlite.
         sequence_name: the name of a postgres sequence to use.
+        table, id_column, stream_name, positive: If set then `check_consistency`
+            is called on the created sequence. See docstring for
+            `check_consistency` details.
     """
     if isinstance(database_engine, PostgresEngine):
-        return PostgresSequenceGenerator(sequence_name)
+        seq = PostgresSequenceGenerator(sequence_name)  # type: SequenceGenerator
     else:
-        return LocalSequenceGenerator(get_first_callback)
+        seq = LocalSequenceGenerator(get_first_callback)
+
+    if table:
+        assert id_column
+        seq.check_consistency(
+            db_conn=db_conn,
+            table=table,
+            id_column=id_column,
+            stream_name=stream_name,
+            positive=positive,
+        )
+
+    return seq
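For background, `check_consistency` exists to catch a PostgreSQL sequence that has fallen behind its table (for example after a bad restore), which would otherwise hand out ids that are already in use. Below is a simplified sketch of the invariant, not the actual implementation in synapse/storage/util/sequence.py; the helper name and error type are illustrative stand-ins:

    def sketch_check_consistency(cursor, sequence_name, table, id_column):
        # Highest id the table has already used.
        cursor.execute("SELECT COALESCE(max(%s), 0) FROM %s" % (id_column, table))
        (max_in_table,) = cursor.fetchone()

        # Where the postgres sequence currently stands.
        cursor.execute("SELECT last_value FROM %s" % (sequence_name,))
        (last_value,) = cursor.fetchone()

        # A sequence behind its table would eventually re-issue taken ids,
        # so fail at startup rather than corrupt data later.
        if max_in_table > last_value:
            raise RuntimeError(
                "Sequence %s is behind table %s (%d < %d)"
                % (sequence_name, table, last_value, max_in_table)
            )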