mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 00:43:46 +01:00
Set our own stream position from the current sequence value on startup (#17309)
This commit is contained in:
parent
12d7303707
commit
f983a77ab0
3 changed files with 147 additions and 178 deletions
1
changelog.d/17309.misc
Normal file
1
changelog.d/17309.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
When rolling back to a previous Synapse version and then forwards again to this release, don't require server operators to manually run SQL.
|
|
@ -276,9 +276,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
# no active writes in progress.
|
# no active writes in progress.
|
||||||
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
|
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
|
||||||
|
|
||||||
# This goes and fills out the above state from the database.
|
|
||||||
self._load_current_ids(db_conn, tables)
|
|
||||||
|
|
||||||
self._sequence_gen = build_sequence_generator(
|
self._sequence_gen = build_sequence_generator(
|
||||||
db_conn=db_conn,
|
db_conn=db_conn,
|
||||||
database_engine=db.engine,
|
database_engine=db.engine,
|
||||||
|
@ -303,6 +300,13 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
positive=positive,
|
positive=positive,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This goes and fills out the above state from the database.
|
||||||
|
# This may read on the PostgreSQL sequence, and
|
||||||
|
# SequenceGenerator.check_consistency might have fixed up the sequence, which
|
||||||
|
# means the SequenceGenerator needs to be setup before we read the value from
|
||||||
|
# the sequence.
|
||||||
|
self._load_current_ids(db_conn, tables, sequence_name)
|
||||||
|
|
||||||
self._max_seen_allocated_stream_id = max(
|
self._max_seen_allocated_stream_id = max(
|
||||||
self._current_positions.values(), default=1
|
self._current_positions.values(), default=1
|
||||||
)
|
)
|
||||||
|
@ -327,6 +331,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
self,
|
self,
|
||||||
db_conn: LoggingDatabaseConnection,
|
db_conn: LoggingDatabaseConnection,
|
||||||
tables: List[Tuple[str, str, str]],
|
tables: List[Tuple[str, str, str]],
|
||||||
|
sequence_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
cur = db_conn.cursor(txn_name="_load_current_ids")
|
cur = db_conn.cursor(txn_name="_load_current_ids")
|
||||||
|
|
||||||
|
@ -360,6 +365,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
if instance in self._writers
|
if instance in self._writers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# If we're a writer, we can assume we're at the end of the stream
|
||||||
|
# Usually, we would get that from the stream_positions, but in some cases,
|
||||||
|
# like if we rolled back Synapse, the stream_positions table might not be up to
|
||||||
|
# date. If we're using Postgres for the sequences, we can just use the current
|
||||||
|
# sequence value as our own position.
|
||||||
|
if self._instance_name in self._writers:
|
||||||
|
if isinstance(self._db.engine, PostgresEngine):
|
||||||
|
cur.execute(f"SELECT last_value FROM {sequence_name}")
|
||||||
|
row = cur.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
self._current_positions[self._instance_name] = row[0]
|
||||||
|
|
||||||
# We set the `_persisted_upto_position` to be the minimum of all current
|
# We set the `_persisted_upto_position` to be the minimum of all current
|
||||||
# positions. If empty we use the max stream ID from the DB table.
|
# positions. If empty we use the max stream ID from the DB table.
|
||||||
min_stream_id = min(self._current_positions.values(), default=None)
|
min_stream_id = min(self._current_positions.values(), default=None)
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -42,9 +42,13 @@ from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
|
||||||
|
|
||||||
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||||
|
positive: bool = True
|
||||||
|
tables: List[str] = ["foobar"]
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
self.db_pool: DatabasePool = self.store.db_pool
|
||||||
|
self.instances: Dict[str, MultiWriterIdGenerator] = {}
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||||
|
|
||||||
|
@ -57,18 +61,22 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||||
if USE_POSTGRES_FOR_TESTS:
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||||
|
|
||||||
txn.execute(
|
for table in self.tables:
|
||||||
"""
|
txn.execute(
|
||||||
CREATE TABLE foobar (
|
"""
|
||||||
stream_id BIGINT NOT NULL,
|
CREATE TABLE %s (
|
||||||
instance_name TEXT NOT NULL,
|
stream_id BIGINT NOT NULL,
|
||||||
data TEXT
|
instance_name TEXT NOT NULL,
|
||||||
);
|
data TEXT
|
||||||
"""
|
);
|
||||||
)
|
"""
|
||||||
|
% (table,)
|
||||||
|
)
|
||||||
|
|
||||||
def _create_id_generator(
|
def _create_id_generator(
|
||||||
self, instance_name: str = "master", writers: Optional[List[str]] = None
|
self,
|
||||||
|
instance_name: str = "master",
|
||||||
|
writers: Optional[List[str]] = None,
|
||||||
) -> MultiWriterIdGenerator:
|
) -> MultiWriterIdGenerator:
|
||||||
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
|
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
|
||||||
return MultiWriterIdGenerator(
|
return MultiWriterIdGenerator(
|
||||||
|
@ -77,36 +85,93 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||||
notifier=self.hs.get_replication_notifier(),
|
notifier=self.hs.get_replication_notifier(),
|
||||||
stream_name="test_stream",
|
stream_name="test_stream",
|
||||||
instance_name=instance_name,
|
instance_name=instance_name,
|
||||||
tables=[("foobar", "instance_name", "stream_id")],
|
tables=[(table, "instance_name", "stream_id") for table in self.tables],
|
||||||
sequence_name="foobar_seq",
|
sequence_name="foobar_seq",
|
||||||
writers=writers or ["master"],
|
writers=writers or ["master"],
|
||||||
|
positive=self.positive,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
self.instances[instance_name] = self.get_success_or_raise(
|
||||||
|
self.db_pool.runWithConnection(_create)
|
||||||
|
)
|
||||||
|
return self.instances[instance_name]
|
||||||
|
|
||||||
def _insert_rows(self, instance_name: str, number: int) -> None:
|
def _replicate(self, instance_name: str) -> None:
|
||||||
|
"""Similate a replication event for the given instance."""
|
||||||
|
|
||||||
|
writer = self.instances[instance_name]
|
||||||
|
token = writer.get_current_token_for_writer(instance_name)
|
||||||
|
for generator in self.instances.values():
|
||||||
|
if writer != generator:
|
||||||
|
generator.advance(instance_name, token)
|
||||||
|
|
||||||
|
def _replicate_all(self) -> None:
|
||||||
|
"""Similate a replication event for all instances."""
|
||||||
|
|
||||||
|
for instance_name in self.instances:
|
||||||
|
self._replicate(instance_name)
|
||||||
|
|
||||||
|
def _insert_row(
|
||||||
|
self, instance_name: str, stream_id: int, table: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Insert one row as the given instance with given stream_id."""
|
||||||
|
|
||||||
|
if table is None:
|
||||||
|
table = self.tables[0]
|
||||||
|
|
||||||
|
factor = 1 if self.positive else -1
|
||||||
|
|
||||||
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
|
txn.execute(
|
||||||
|
"INSERT INTO %s VALUES (?, ?)" % (table,),
|
||||||
|
(
|
||||||
|
stream_id,
|
||||||
|
instance_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
txn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||||
|
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||||
|
""",
|
||||||
|
(instance_name, stream_id * factor, stream_id * factor),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
|
||||||
|
|
||||||
|
def _insert_rows(
|
||||||
|
self,
|
||||||
|
instance_name: str,
|
||||||
|
number: int,
|
||||||
|
table: Optional[str] = None,
|
||||||
|
update_stream_table: bool = True,
|
||||||
|
) -> None:
|
||||||
"""Insert N rows as the given instance, inserting with stream IDs pulled
|
"""Insert N rows as the given instance, inserting with stream IDs pulled
|
||||||
from the postgres sequence.
|
from the postgres sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if table is None:
|
||||||
|
table = self.tables[0]
|
||||||
|
|
||||||
|
factor = 1 if self.positive else -1
|
||||||
|
|
||||||
def _insert(txn: LoggingTransaction) -> None:
|
def _insert(txn: LoggingTransaction) -> None:
|
||||||
for _ in range(number):
|
for _ in range(number):
|
||||||
next_val = self.seq_gen.get_next_id_txn(txn)
|
next_val = self.seq_gen.get_next_id_txn(txn)
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
|
"INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
|
||||||
(
|
% (table,),
|
||||||
next_val,
|
(next_val, instance_name),
|
||||||
instance_name,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(
|
if update_stream_table:
|
||||||
"""
|
txn.execute(
|
||||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
"""
|
||||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
||||||
""",
|
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
||||||
(instance_name, next_val, next_val),
|
""",
|
||||||
)
|
(instance_name, next_val * factor, next_val * factor),
|
||||||
|
)
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
||||||
|
|
||||||
|
@ -353,7 +418,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
|
|
||||||
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
|
# When the writer is created, it assumes its own position is the current head of
|
||||||
|
# the sequence
|
||||||
|
self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
|
||||||
|
|
||||||
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
|
||||||
|
|
||||||
|
@ -375,11 +442,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
correctly.
|
correctly.
|
||||||
"""
|
"""
|
||||||
self._insert_rows("first", 3)
|
self._insert_rows("first", 3)
|
||||||
self._insert_rows("second", 4)
|
|
||||||
|
|
||||||
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||||
|
|
||||||
|
self._insert_rows("second", 4)
|
||||||
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
||||||
|
|
||||||
|
self._replicate_all()
|
||||||
|
|
||||||
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
|
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
|
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
|
||||||
|
@ -398,6 +467,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
first_id_gen.get_positions(), {"first": 3, "second": 7}
|
first_id_gen.get_positions(), {"first": 3, "second": 7}
|
||||||
)
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
second_id_gen.get_positions(), {"first": 3, "second": 7}
|
||||||
|
)
|
||||||
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
||||||
self.get_success(_get_next_async())
|
self.get_success(_get_next_async())
|
||||||
|
@ -432,11 +504,11 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
"""
|
"""
|
||||||
# Insert some rows for two out of three of the ID gens.
|
# Insert some rows for two out of three of the ID gens.
|
||||||
self._insert_rows("first", 3)
|
self._insert_rows("first", 3)
|
||||||
self._insert_rows("second", 4)
|
|
||||||
|
|
||||||
first_id_gen = self._create_id_generator(
|
first_id_gen = self._create_id_generator(
|
||||||
"first", writers=["first", "second", "third"]
|
"first", writers=["first", "second", "third"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._insert_rows("second", 4)
|
||||||
second_id_gen = self._create_id_generator(
|
second_id_gen = self._create_id_generator(
|
||||||
"second", writers=["first", "second", "third"]
|
"second", writers=["first", "second", "third"]
|
||||||
)
|
)
|
||||||
|
@ -444,6 +516,8 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
"third", writers=["first", "second", "third"]
|
"third", writers=["first", "second", "third"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._replicate_all()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
|
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
|
||||||
)
|
)
|
||||||
|
@ -546,11 +620,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
|
|
||||||
def test_minimal_local_token(self) -> None:
|
def test_minimal_local_token(self) -> None:
|
||||||
self._insert_rows("first", 3)
|
self._insert_rows("first", 3)
|
||||||
self._insert_rows("second", 4)
|
|
||||||
|
|
||||||
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||||
|
|
||||||
|
self._insert_rows("second", 4)
|
||||||
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
||||||
|
|
||||||
|
self._replicate_all()
|
||||||
|
|
||||||
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||||
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
|
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
|
||||||
|
|
||||||
|
@ -562,15 +638,17 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
token when there are no writes.
|
token when there are no writes.
|
||||||
"""
|
"""
|
||||||
self._insert_rows("first", 3)
|
self._insert_rows("first", 3)
|
||||||
self._insert_rows("second", 4)
|
|
||||||
|
|
||||||
first_id_gen = self._create_id_generator(
|
first_id_gen = self._create_id_generator(
|
||||||
"first", writers=["first", "second", "third"]
|
"first", writers=["first", "second", "third"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._insert_rows("second", 4)
|
||||||
second_id_gen = self._create_id_generator(
|
second_id_gen = self._create_id_generator(
|
||||||
"second", writers=["first", "second", "third"]
|
"second", writers=["first", "second", "third"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._replicate_all()
|
||||||
|
|
||||||
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
|
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
|
||||||
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
|
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
|
||||||
self.assertEqual(second_id_gen.get_current_token(), 7)
|
self.assertEqual(second_id_gen.get_current_token(), 7)
|
||||||
|
@ -609,68 +687,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
self.assertEqual(second_id_gen.get_current_token(), 7)
|
self.assertEqual(second_id_gen.get_current_token(), 7)
|
||||||
|
|
||||||
|
|
||||||
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
|
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
|
||||||
|
|
||||||
if not USE_POSTGRES_FOR_TESTS:
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
skip = "Requires Postgres"
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
positive = False
|
||||||
self.store = hs.get_datastores().main
|
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
|
||||||
|
|
||||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
|
||||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
|
||||||
txn.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE foobar (
|
|
||||||
stream_id BIGINT NOT NULL,
|
|
||||||
instance_name TEXT NOT NULL,
|
|
||||||
data TEXT
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_id_generator(
|
|
||||||
self, instance_name: str = "master", writers: Optional[List[str]] = None
|
|
||||||
) -> MultiWriterIdGenerator:
|
|
||||||
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
|
|
||||||
return MultiWriterIdGenerator(
|
|
||||||
conn,
|
|
||||||
self.db_pool,
|
|
||||||
notifier=self.hs.get_replication_notifier(),
|
|
||||||
stream_name="test_stream",
|
|
||||||
instance_name=instance_name,
|
|
||||||
tables=[("foobar", "instance_name", "stream_id")],
|
|
||||||
sequence_name="foobar_seq",
|
|
||||||
writers=writers or ["master"],
|
|
||||||
positive=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.get_success(self.db_pool.runWithConnection(_create))
|
|
||||||
|
|
||||||
def _insert_row(self, instance_name: str, stream_id: int) -> None:
|
|
||||||
"""Insert one row as the given instance with given stream_id."""
|
|
||||||
|
|
||||||
def _insert(txn: LoggingTransaction) -> None:
|
|
||||||
txn.execute(
|
|
||||||
"INSERT INTO foobar VALUES (?, ?)",
|
|
||||||
(
|
|
||||||
stream_id,
|
|
||||||
instance_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
txn.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
|
|
||||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
|
|
||||||
""",
|
|
||||||
(instance_name, -stream_id, -stream_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
|
|
||||||
|
|
||||||
def test_single_instance(self) -> None:
|
def test_single_instance(self) -> None:
|
||||||
"""Test that reads and writes from a single process are handled
|
"""Test that reads and writes from a single process are handled
|
||||||
|
@ -716,7 +739,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
async def _get_next_async() -> None:
|
async def _get_next_async() -> None:
|
||||||
async with id_gen_1.get_next() as stream_id:
|
async with id_gen_1.get_next() as stream_id:
|
||||||
self._insert_row("first", stream_id)
|
self._insert_row("first", stream_id)
|
||||||
id_gen_2.advance("first", stream_id)
|
self._replicate("first")
|
||||||
|
|
||||||
self.get_success(_get_next_async())
|
self.get_success(_get_next_async())
|
||||||
|
|
||||||
|
@ -728,7 +751,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
async def _get_next_async2() -> None:
|
async def _get_next_async2() -> None:
|
||||||
async with id_gen_2.get_next() as stream_id:
|
async with id_gen_2.get_next() as stream_id:
|
||||||
self._insert_row("second", stream_id)
|
self._insert_row("second", stream_id)
|
||||||
id_gen_1.advance("second", stream_id)
|
self._replicate("second")
|
||||||
|
|
||||||
self.get_success(_get_next_async2())
|
self.get_success(_get_next_async2())
|
||||||
|
|
||||||
|
@ -738,98 +761,26 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|
||||||
|
|
||||||
|
|
||||||
class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
|
||||||
if not USE_POSTGRES_FOR_TESTS:
|
if not USE_POSTGRES_FOR_TESTS:
|
||||||
skip = "Requires Postgres"
|
skip = "Requires Postgres"
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
tables = ["foobar1", "foobar2"]
|
||||||
self.store = hs.get_datastores().main
|
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
|
||||||
|
|
||||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
|
||||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
|
||||||
txn.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE foobar1 (
|
|
||||||
stream_id BIGINT NOT NULL,
|
|
||||||
instance_name TEXT NOT NULL,
|
|
||||||
data TEXT
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE foobar2 (
|
|
||||||
stream_id BIGINT NOT NULL,
|
|
||||||
instance_name TEXT NOT NULL,
|
|
||||||
data TEXT
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_id_generator(
|
|
||||||
self, instance_name: str = "master", writers: Optional[List[str]] = None
|
|
||||||
) -> MultiWriterIdGenerator:
|
|
||||||
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
|
|
||||||
return MultiWriterIdGenerator(
|
|
||||||
conn,
|
|
||||||
self.db_pool,
|
|
||||||
notifier=self.hs.get_replication_notifier(),
|
|
||||||
stream_name="test_stream",
|
|
||||||
instance_name=instance_name,
|
|
||||||
tables=[
|
|
||||||
("foobar1", "instance_name", "stream_id"),
|
|
||||||
("foobar2", "instance_name", "stream_id"),
|
|
||||||
],
|
|
||||||
sequence_name="foobar_seq",
|
|
||||||
writers=writers or ["master"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
|
||||||
|
|
||||||
def _insert_rows(
|
|
||||||
self,
|
|
||||||
table: str,
|
|
||||||
instance_name: str,
|
|
||||||
number: int,
|
|
||||||
update_stream_table: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""Insert N rows as the given instance, inserting with stream IDs pulled
|
|
||||||
from the postgres sequence.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _insert(txn: LoggingTransaction) -> None:
|
|
||||||
for _ in range(number):
|
|
||||||
txn.execute(
|
|
||||||
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
|
|
||||||
(instance_name,),
|
|
||||||
)
|
|
||||||
if update_stream_table:
|
|
||||||
txn.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
|
|
||||||
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
|
|
||||||
""",
|
|
||||||
(instance_name,),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
|
|
||||||
|
|
||||||
def test_load_existing_stream(self) -> None:
|
def test_load_existing_stream(self) -> None:
|
||||||
"""Test creating ID gens with multiple tables that have rows from after
|
"""Test creating ID gens with multiple tables that have rows from after
|
||||||
the position in `stream_positions` table.
|
the position in `stream_positions` table.
|
||||||
"""
|
"""
|
||||||
self._insert_rows("foobar1", "first", 3)
|
self._insert_rows("first", 3, table="foobar1")
|
||||||
self._insert_rows("foobar2", "second", 3)
|
|
||||||
self._insert_rows("foobar2", "second", 1, update_stream_table=False)
|
|
||||||
|
|
||||||
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
|
||||||
|
|
||||||
|
self._insert_rows("second", 3, table="foobar2")
|
||||||
|
self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
|
||||||
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
|
||||||
|
|
||||||
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
|
self._replicate_all()
|
||||||
|
|
||||||
|
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
|
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
|
||||||
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
|
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
|
||||||
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
Loading…
Reference in a new issue