0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 23:11:34 +01:00

Merge branch 'release-v1.109' into develop

This commit is contained in:
Quentin Gliech 2024-06-17 15:51:16 +02:00
commit e88332b5f4
No known key found for this signature in database
GPG key ID: 22D62B84552719FC
8 changed files with 193 additions and 216 deletions

View file

@ -479,6 +479,9 @@ jobs:
volumes: volumes:
- ${{ github.workspace }}:/src - ${{ github.workspace }}:/src
env: env:
# If this is a pull request to a release branch, use that branch as default branch for sytest, else use develop
# This works because the release script always create a branch on the sytest repo with the same name as the release branch
SYTEST_DEFAULT_BRANCH: ${{ startsWith(github.base_ref, 'release-') && github.base_ref || 'develop' }}
SYTEST_BRANCH: ${{ github.head_ref }} SYTEST_BRANCH: ${{ github.head_ref }}
POSTGRES: ${{ matrix.job.postgres && 1}} POSTGRES: ${{ matrix.job.postgres && 1}}
MULTI_POSTGRES: ${{ (matrix.job.postgres == 'multi-postgres') || '' }} MULTI_POSTGRES: ${{ (matrix.job.postgres == 'multi-postgres') || '' }}

View file

@ -1,3 +1,16 @@
# Synapse 1.109.0rc3 (2024-06-17)
### Bugfixes
- When rolling back to a previous Synapse version and then forwards again to this release, don't require server operators to manually run SQL. ([\#17305](https://github.com/element-hq/synapse/issues/17305), [\#17309](https://github.com/element-hq/synapse/issues/17309))
### Internal Changes
- Use the release branch for sytest in release-branch PRs. ([\#17306](https://github.com/element-hq/synapse/issues/17306))
# Synapse 1.109.0rc2 (2024-06-11) # Synapse 1.109.0rc2 (2024-06-11)
### Bugfixes ### Bugfixes

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.109.0~rc3) stable; urgency=medium
* New synapse release 1.109.0rc3.
-- Synapse Packaging team <packages@matrix.org> Mon, 17 Jun 2024 12:05:24 +0000
matrix-synapse-py3 (1.109.0~rc2) stable; urgency=medium matrix-synapse-py3 (1.109.0~rc2) stable; urgency=medium
* New synapse release 1.109.0rc2. * New synapse release 1.109.0rc2.

View file

@ -255,13 +255,3 @@ however extreme care must be taken to avoid database corruption.
Note that the above may fail with an error about duplicate rows if corruption Note that the above may fail with an error about duplicate rows if corruption
has already occurred, and such duplicate rows will need to be manually removed. has already occurred, and such duplicate rows will need to be manually removed.
### Fixing inconsistent sequences error
Synapse uses Postgres sequences to generate IDs for various tables. A sequence
and associated table can get out of sync if, for example, Synapse has been
downgraded and then upgraded again.
To fix the issue shut down Synapse (including any and all workers) and run the
SQL command included in the error message. Once done Synapse should start
successfully.

View file

@ -96,7 +96,7 @@ module-name = "synapse.synapse_rust"
[tool.poetry] [tool.poetry]
name = "matrix-synapse" name = "matrix-synapse"
version = "1.109.0rc2" version = "1.109.0rc3"
description = "Homeserver for the Matrix decentralised comms protocol" description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"] authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "AGPL-3.0-or-later" license = "AGPL-3.0-or-later"

View file

@ -276,9 +276,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# no active writes in progress. # no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id self._max_position_of_local_instance = self._max_seen_allocated_stream_id
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)
self._sequence_gen = build_sequence_generator( self._sequence_gen = build_sequence_generator(
db_conn=db_conn, db_conn=db_conn,
database_engine=db.engine, database_engine=db.engine,
@ -303,6 +300,13 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive=positive, positive=positive,
) )
# This goes and fills out the above state from the database.
# This may read on the PostgreSQL sequence, and
# SequenceGenerator.check_consistency might have fixed up the sequence, which
# means the SequenceGenerator needs to be setup before we read the value from
# the sequence.
self._load_current_ids(db_conn, tables, sequence_name)
self._max_seen_allocated_stream_id = max( self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1 self._current_positions.values(), default=1
) )
@ -327,6 +331,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self, self,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]], tables: List[Tuple[str, str, str]],
sequence_name: str,
) -> None: ) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids") cur = db_conn.cursor(txn_name="_load_current_ids")
@ -360,6 +365,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance in self._writers if instance in self._writers
} }
# If we're a writer, we can assume we're at the end of the stream
# Usually, we would get that from the stream_positions, but in some cases,
# like if we rolled back Synapse, the stream_positions table might not be up to
# date. If we're using Postgres for the sequences, we can just use the current
# sequence value as our own position.
if self._instance_name in self._writers:
if isinstance(self._db.engine, PostgresEngine):
cur.execute(f"SELECT last_value FROM {sequence_name}")
row = cur.fetchone()
assert row is not None
self._current_positions[self._instance_name] = row[0]
# We set the `_persisted_upto_position` to be the minimum of all current # We set the `_persisted_upto_position` to be the minimum of all current
# positions. If empty we use the max stream ID from the DB table. # positions. If empty we use the max stream ID from the DB table.
min_stream_id = min(self._current_positions.values(), default=None) min_stream_id = min(self._current_positions.values(), default=None)

View file

@ -36,21 +36,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_INCONSISTENT_SEQUENCE_ERROR = """
Postgres sequence '%(seq)s' is inconsistent with associated
table '%(table)s'. This can happen if Synapse has been downgraded and
then upgraded again, or due to a bad migration.
To fix this error, shut down Synapse (including any and all workers)
and run the following SQL:
SELECT setval('%(seq)s', (
%(max_id_sql)s
));
See docs/postgres.md for more information.
"""
_INCONSISTENT_STREAM_ERROR = """ _INCONSISTENT_STREAM_ERROR = """
Postgres sequence '%(seq)s' is inconsistent with associated stream position Postgres sequence '%(seq)s' is inconsistent with associated stream position
of '%(stream_name)s' in the 'stream_positions' table. of '%(stream_name)s' in the 'stream_positions' table.
@ -169,25 +154,33 @@ class PostgresSequenceGenerator(SequenceGenerator):
if row: if row:
max_in_stream_positions = row[0] max_in_stream_positions = row[0]
txn.close()
# If `is_called` is False then `last_value` is actually the value that # If `is_called` is False then `last_value` is actually the value that
# will be generated next, so we decrement to get the true "last value". # will be generated next, so we decrement to get the true "last value".
if not is_called: if not is_called:
last_value -= 1 last_value -= 1
if max_stream_id > last_value: if max_stream_id > last_value:
# The sequence is lagging behind the tables. This is probably due to
# rolling back to a version before the sequence was used and then
# forwards again. We resolve this by setting the sequence to the
# right value.
logger.warning( logger.warning(
"Postgres sequence %s is behind table %s: %d < %d", "Postgres sequence %s is behind table %s: %d < %d. Updating sequence.",
self._sequence_name, self._sequence_name,
table, table,
last_value, last_value,
max_stream_id, max_stream_id,
) )
raise IncorrectDatabaseSetup(
_INCONSISTENT_SEQUENCE_ERROR sql = f"""
% {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} SELECT setval('{self._sequence_name}', GREATEST(
) (SELECT last_value FROM {self._sequence_name}),
({table_sql})
));
"""
txn.execute(sql)
txn.close()
# If we have values in the stream positions table then they have to be # If we have values in the stream positions table then they have to be
# less than or equal to `last_value` # less than or equal to `last_value`

View file

@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import List, Optional from typing import Dict, List, Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -28,7 +28,6 @@ from synapse.storage.database import (
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import ( from synapse.storage.util.sequence import (
@ -43,9 +42,13 @@ from tests.utils import USE_POSTGRES_FOR_TESTS
class MultiWriterIdGeneratorBase(HomeserverTestCase): class MultiWriterIdGeneratorBase(HomeserverTestCase):
positive: bool = True
tables: List[str] = ["foobar"]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.instances: Dict[str, MultiWriterIdGenerator] = {}
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
@ -58,18 +61,22 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
if USE_POSTGRES_FOR_TESTS: if USE_POSTGRES_FOR_TESTS:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
for table in self.tables:
txn.execute( txn.execute(
""" """
CREATE TABLE foobar ( CREATE TABLE %s (
stream_id BIGINT NOT NULL, stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL, instance_name TEXT NOT NULL,
data TEXT data TEXT
); );
""" """
% (table,)
) )
def _create_id_generator( def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None self,
instance_name: str = "master",
writers: Optional[List[str]] = None,
) -> MultiWriterIdGenerator: ) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
@ -78,35 +85,92 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
notifier=self.hs.get_replication_notifier(), notifier=self.hs.get_replication_notifier(),
stream_name="test_stream", stream_name="test_stream",
instance_name=instance_name, instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")], tables=[(table, "instance_name", "stream_id") for table in self.tables],
sequence_name="foobar_seq", sequence_name="foobar_seq",
writers=writers or ["master"], writers=writers or ["master"],
positive=self.positive,
) )
return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) self.instances[instance_name] = self.get_success_or_raise(
self.db_pool.runWithConnection(_create)
)
return self.instances[instance_name]
def _insert_rows(self, instance_name: str, number: int) -> None: def _replicate(self, instance_name: str) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled """Similate a replication event for the given instance."""
from the postgres sequence.
""" writer = self.instances[instance_name]
token = writer.get_current_token_for_writer(instance_name)
for generator in self.instances.values():
if writer != generator:
generator.advance(instance_name, token)
def _replicate_all(self) -> None:
"""Similate a replication event for all instances."""
for instance_name in self.instances:
self._replicate(instance_name)
def _insert_row(
self, instance_name: str, stream_id: int, table: Optional[str] = None
) -> None:
"""Insert one row as the given instance with given stream_id."""
if table is None:
table = self.tables[0]
factor = 1 if self.positive else -1
def _insert(txn: LoggingTransaction) -> None: def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
next_val = self.seq_gen.get_next_id_txn(txn)
txn.execute( txn.execute(
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)", "INSERT INTO %s VALUES (?, ?)" % (table,),
( (
next_val, stream_id,
instance_name, instance_name,
), ),
) )
txn.execute( txn.execute(
""" """
INSERT INTO stream_positions VALUES ('test_stream', ?, ?) INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""", """,
(instance_name, next_val, next_val), (instance_name, stream_id * factor, stream_id * factor),
)
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def _insert_rows(
self,
instance_name: str,
number: int,
table: Optional[str] = None,
update_stream_table: bool = True,
) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
if table is None:
table = self.tables[0]
factor = 1 if self.positive else -1
def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
next_val = self.seq_gen.get_next_id_txn(txn)
txn.execute(
"INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
% (table,),
(next_val, instance_name),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, next_val * factor, next_val * factor),
) )
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@ -354,7 +418,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
id_gen = self._create_id_generator("first", writers=["first", "second"]) id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) # When the writer is created, it assumes its own position is the current head of
# the sequence
self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@ -376,11 +442,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
correctly. correctly.
""" """
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self._replicate_all()
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
@ -399,6 +467,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual( self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7} first_id_gen.get_positions(), {"first": 3, "second": 7}
) )
self.assertEqual(
second_id_gen.get_positions(), {"first": 3, "second": 7}
)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.get_success(_get_next_async()) self.get_success(_get_next_async())
@ -433,11 +504,11 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
""" """
# Insert some rows for two out of three of the ID gens. # Insert some rows for two out of three of the ID gens.
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator( first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"] "first", writers=["first", "second", "third"]
) )
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator( second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"] "second", writers=["first", "second", "third"]
) )
@ -445,6 +516,8 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"third", writers=["first", "second", "third"] "third", writers=["first", "second", "third"]
) )
self._replicate_all()
self.assertEqual( self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
) )
@ -525,7 +598,7 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
def test_sequence_consistency(self) -> None: def test_sequence_consistency(self) -> None:
"""Test that we error out if the table and sequence diverges.""" """Test that we correct the sequence if the table and sequence diverges."""
# Prefill with some rows # Prefill with some rows
self._insert_row_with_id("master", 3) self._insert_row_with_id("master", 3)
@ -536,17 +609,24 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.get_success(self.db_pool.runInteraction("_insert", _insert)) self.get_success(self.db_pool.runInteraction("_insert", _insert))
# Creating the ID gen should error # Creating the ID gen should now fix the inconsistency
with self.assertRaises(IncorrectDatabaseSetup): id_gen = self._create_id_generator()
self._create_id_generator("first")
async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 27)
self.get_success(_get_next_async())
def test_minimal_local_token(self) -> None: def test_minimal_local_token(self) -> None:
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self._replicate_all()
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3) self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
@ -558,15 +638,17 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
token when there are no writes. token when there are no writes.
""" """
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator( first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"] "first", writers=["first", "second", "third"]
) )
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator( second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"] "second", writers=["first", "second", "third"]
) )
self._replicate_all()
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7) self.assertEqual(second_id_gen.get_current_token(), 7)
@ -605,68 +687,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(second_id_gen.get_current_token(), 7) self.assertEqual(second_id_gen.get_current_token(), 7)
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.""" """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: positive = False
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers or ["master"],
positive=False,
)
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id."""
def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
stream_id,
instance_name,
),
)
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, -stream_id, -stream_id),
)
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def test_single_instance(self) -> None: def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
@ -712,7 +739,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
async def _get_next_async() -> None: async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id: async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id) self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id) self._replicate("first")
self.get_success(_get_next_async()) self.get_success(_get_next_async())
@ -724,7 +751,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
async def _get_next_async2() -> None: async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id: async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id) self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id) self._replicate("second")
self.get_success(_get_next_async2()) self.get_success(_get_next_async2())
@ -734,98 +761,26 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: tables = ["foobar1", "foobar2"]
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar1 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
txn.execute(
"""
CREATE TABLE foobar2 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[
("foobar1", "instance_name", "stream_id"),
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(
self,
table: str,
instance_name: str,
number: int,
update_stream_table: bool = True,
) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
(instance_name,),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
""",
(instance_name,),
)
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self) -> None: def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after """Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table. the position in `stream_positions` table.
""" """
self._insert_rows("foobar1", "first", 3) self._insert_rows("first", 3, table="foobar1")
self._insert_rows("foobar2", "second", 3)
self._insert_rows("foobar2", "second", 1, update_stream_table=False)
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
self._insert_rows("second", 3, table="foobar2")
self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) self._replicate_all()
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)