mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 06:13:52 +01:00
Replaces all usages of StreamIdGenerator
with MultiWriterIdGenerator
(#17229)
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`, which is safer.
This commit is contained in:
parent
225f378ffa
commit
d16910ca02
10 changed files with 227 additions and 363 deletions
1
changelog.d/17229.misc
Normal file
1
changelog.d/17229.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.
|
|
@ -777,22 +777,74 @@ class Porter:
|
||||||
await self._setup_events_stream_seqs()
|
await self._setup_events_stream_seqs()
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"un_partial_stated_event_stream_sequence",
|
"un_partial_stated_event_stream_sequence",
|
||||||
("un_partial_stated_event_stream",),
|
[("un_partial_stated_event_stream", "stream_id")],
|
||||||
)
|
)
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
|
"device_inbox_sequence",
|
||||||
|
[
|
||||||
|
("device_inbox", "stream_id"),
|
||||||
|
("device_federation_outbox", "stream_id"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"account_data_sequence",
|
"account_data_sequence",
|
||||||
("room_account_data", "room_tags_revisions", "account_data"),
|
[
|
||||||
|
("room_account_data", "stream_id"),
|
||||||
|
("room_tags_revisions", "stream_id"),
|
||||||
|
("account_data", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"receipts_sequence",
|
||||||
|
[
|
||||||
|
("receipts_linearized", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"presence_stream_sequence",
|
||||||
|
[
|
||||||
|
("presence_stream", "stream_id"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
|
|
||||||
await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
|
|
||||||
await self._setup_auth_chain_sequence()
|
await self._setup_auth_chain_sequence()
|
||||||
await self._setup_sequence(
|
await self._setup_sequence(
|
||||||
"application_services_txn_id_seq",
|
"application_services_txn_id_seq",
|
||||||
("application_services_txns",),
|
[
|
||||||
"txn_id",
|
(
|
||||||
|
"application_services_txns",
|
||||||
|
"txn_id",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"device_lists_sequence",
|
||||||
|
[
|
||||||
|
("device_lists_stream", "stream_id"),
|
||||||
|
("user_signature_stream", "stream_id"),
|
||||||
|
("device_lists_outbound_pokes", "stream_id"),
|
||||||
|
("device_lists_changes_in_room", "stream_id"),
|
||||||
|
("device_lists_remote_pending", "stream_id"),
|
||||||
|
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"e2e_cross_signing_keys_sequence",
|
||||||
|
[
|
||||||
|
("e2e_cross_signing_keys", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"push_rules_stream_sequence",
|
||||||
|
[
|
||||||
|
("push_rules_stream", "stream_id"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
await self._setup_sequence(
|
||||||
|
"pushers_sequence",
|
||||||
|
[
|
||||||
|
("pushers", "id"),
|
||||||
|
("deleted_pushers", "stream_id"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3. Get tables.
|
# Step 3. Get tables.
|
||||||
|
@ -1101,12 +1153,11 @@ class Porter:
|
||||||
async def _setup_sequence(
|
async def _setup_sequence(
|
||||||
self,
|
self,
|
||||||
sequence_name: str,
|
sequence_name: str,
|
||||||
stream_id_tables: Iterable[str],
|
stream_id_tables: Iterable[Tuple[str, str]],
|
||||||
column_name: str = "stream_id",
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set a sequence to the correct value."""
|
"""Set a sequence to the correct value."""
|
||||||
current_stream_ids = []
|
current_stream_ids = []
|
||||||
for stream_id_table in stream_id_tables:
|
for stream_id_table, column_name in stream_id_tables:
|
||||||
max_stream_id = cast(
|
max_stream_id = cast(
|
||||||
int,
|
int,
|
||||||
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
|
|
|
@ -57,10 +57,7 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
AbstractStreamIdGenerator,
|
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
JsonDict,
|
JsonDict,
|
||||||
JsonMapping,
|
JsonMapping,
|
||||||
|
@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
|
|
||||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||||
# class below that is used on the main process.
|
# class below that is used on the main process.
|
||||||
self._device_list_id_gen = StreamIdGenerator(
|
self._device_list_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn,
|
db_conn=db_conn,
|
||||||
hs.get_replication_notifier(),
|
db=database,
|
||||||
"device_lists_stream",
|
notifier=hs.get_replication_notifier(),
|
||||||
"stream_id",
|
stream_name="device_lists_stream",
|
||||||
extra_tables=[
|
instance_name=self._instance_name,
|
||||||
("user_signature_stream", "stream_id"),
|
tables=[
|
||||||
("device_lists_outbound_pokes", "stream_id"),
|
("device_lists_stream", "instance_name", "stream_id"),
|
||||||
("device_lists_changes_in_room", "stream_id"),
|
("user_signature_stream", "instance_name", "stream_id"),
|
||||||
("device_lists_remote_pending", "stream_id"),
|
("device_lists_outbound_pokes", "instance_name", "stream_id"),
|
||||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
("device_lists_changes_in_room", "instance_name", "stream_id"),
|
||||||
|
("device_lists_remote_pending", "instance_name", "stream_id"),
|
||||||
],
|
],
|
||||||
is_writer=hs.config.worker.worker_app is None,
|
sequence_name="device_lists_sequence",
|
||||||
|
writers=["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
device_list_max = self._device_list_id_gen.get_current_token()
|
device_list_max = self._device_list_id_gen.get_current_token()
|
||||||
|
@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
"from_user_id": from_user_id,
|
"from_user_id": from_user_id,
|
||||||
"user_ids": json_encoder.encode(user_ids),
|
"user_ids": json_encoder.encode(user_ids),
|
||||||
|
"instance_name": self._instance_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
self.db_pool.updates.register_background_index_update(
|
self.db_pool.updates.register_background_index_update(
|
||||||
"device_lists_stream_idx",
|
"device_lists_stream_idx",
|
||||||
index_name="device_lists_stream_user_id",
|
index_name="device_lists_stream_user_id",
|
||||||
|
@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||||
"device_lists_outbound_pokes",
|
"device_lists_outbound_pokes",
|
||||||
{
|
{
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"destination": destination,
|
"destination": destination,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
|
@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
|
|
||||||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
# Because we have write access, this will be a StreamIdGenerator
|
|
||||||
# (see DeviceWorkerStore.__init__)
|
|
||||||
_device_list_id_gen: AbstractStreamIdGenerator
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
|
@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
self.db_pool.simple_insert_many_txn(
|
self.db_pool.simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="device_lists_stream",
|
table="device_lists_stream",
|
||||||
keys=("stream_id", "user_id", "device_id"),
|
keys=("instance_name", "stream_id", "user_id", "device_id"),
|
||||||
values=[
|
values=[
|
||||||
(stream_id, user_id, device_id)
|
(self._instance_name, stream_id, user_id, device_id)
|
||||||
for stream_id, device_id in zip(stream_ids, device_ids)
|
for stream_id, device_id in zip(stream_ids, device_ids)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
values = [
|
values = [
|
||||||
(
|
(
|
||||||
destination,
|
destination,
|
||||||
|
self._instance_name,
|
||||||
next(stream_id_iterator),
|
next(stream_id_iterator),
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
|
@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
table="device_lists_outbound_pokes",
|
table="device_lists_outbound_pokes",
|
||||||
keys=(
|
keys=(
|
||||||
"destination",
|
"destination",
|
||||||
|
"instance_name",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
"user_id",
|
"user_id",
|
||||||
"device_id",
|
"device_id",
|
||||||
|
@ -2157,7 +2158,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
device_id,
|
device_id,
|
||||||
{
|
{
|
||||||
stream_id: destination
|
stream_id: destination
|
||||||
for (destination, stream_id, _, _, _, _, _) in values
|
for (destination, _, stream_id, _, _, _, _, _) in values
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2210,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
"device_id",
|
"device_id",
|
||||||
"room_id",
|
"room_id",
|
||||||
"stream_id",
|
"stream_id",
|
||||||
|
"instance_name",
|
||||||
"converted_to_destinations",
|
"converted_to_destinations",
|
||||||
"opentracing_context",
|
"opentracing_context",
|
||||||
),
|
),
|
||||||
|
@ -2219,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
device_id,
|
device_id,
|
||||||
room_id,
|
room_id,
|
||||||
stream_id,
|
stream_id,
|
||||||
|
self._instance_name,
|
||||||
# We only need to calculate outbound pokes for local users
|
# We only need to calculate outbound pokes for local users
|
||||||
not self.hs.is_mine_id(user_id),
|
not self.hs.is_mine_id(user_id),
|
||||||
encoded_context,
|
encoded_context,
|
||||||
|
@ -2338,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
},
|
},
|
||||||
values={"stream_id": stream_id},
|
values={
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
|
},
|
||||||
desc="add_remote_device_list_to_pending",
|
desc="add_remote_device_list_to_pending",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
from synapse.types import JsonDict, JsonMapping
|
from synapse.types import JsonDict, JsonMapping
|
||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._cross_signing_id_gen = StreamIdGenerator(
|
self._cross_signing_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn,
|
db_conn=db_conn,
|
||||||
hs.get_replication_notifier(),
|
db=database,
|
||||||
"e2e_cross_signing_keys",
|
notifier=hs.get_replication_notifier(),
|
||||||
"stream_id",
|
stream_name="e2e_cross_signing_keys",
|
||||||
|
instance_name=self._instance_name,
|
||||||
|
tables=[
|
||||||
|
("e2e_cross_signing_keys", "instance_name", "stream_id"),
|
||||||
|
],
|
||||||
|
sequence_name="e2e_cross_signing_keys_sequence",
|
||||||
|
writers=["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_e2e_device_keys(
|
async def set_e2e_device_keys(
|
||||||
|
@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
"keytype": key_type,
|
"keytype": key_type,
|
||||||
"keydata": json_encoder.encode(key),
|
"keydata": json_encoder.encode(key),
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||||
from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
|
||||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
|
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder, unwrapFirstError
|
from synapse.util import json_encoder, unwrapFirstError
|
||||||
|
@ -126,7 +126,7 @@ class PushRulesWorkerStore(
|
||||||
`get_max_push_rules_stream_id` which can be called in the initializer.
|
`get_max_push_rules_stream_id` which can be called in the initializer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_push_rules_stream_id_gen: StreamIdGenerator
|
_push_rules_stream_id_gen: MultiWriterIdGenerator
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -140,14 +140,17 @@ class PushRulesWorkerStore(
|
||||||
hs.get_instance_name() in hs.config.worker.writers.push_rules
|
hs.get_instance_name() in hs.config.worker.writers.push_rules
|
||||||
)
|
)
|
||||||
|
|
||||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
self._push_rules_stream_id_gen = MultiWriterIdGenerator(
|
||||||
# class below that is used on the main process.
|
db_conn=db_conn,
|
||||||
self._push_rules_stream_id_gen = StreamIdGenerator(
|
db=database,
|
||||||
db_conn,
|
notifier=hs.get_replication_notifier(),
|
||||||
hs.get_replication_notifier(),
|
stream_name="push_rules_stream",
|
||||||
"push_rules_stream",
|
instance_name=self._instance_name,
|
||||||
"stream_id",
|
tables=[
|
||||||
is_writer=self._is_push_writer,
|
("push_rules_stream", "instance_name", "stream_id"),
|
||||||
|
],
|
||||||
|
sequence_name="push_rules_stream_sequence",
|
||||||
|
writers=hs.config.worker.writers.push_rules,
|
||||||
)
|
)
|
||||||
|
|
||||||
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
|
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
|
||||||
|
@ -880,6 +883,7 @@ class PushRulesWorkerStore(
|
||||||
raise Exception("Not a push writer")
|
raise Exception("Not a push writer")
|
||||||
|
|
||||||
values = {
|
values = {
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
"event_stream_ordering": event_stream_ordering,
|
"event_stream_ordering": event_stream_ordering,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
|
@ -40,10 +40,7 @@ from synapse.storage.database import (
|
||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
AbstractStreamIdGenerator,
|
|
||||||
StreamIdGenerator,
|
|
||||||
)
|
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
|
||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
self._instance_name = hs.get_instance_name()
|
||||||
# class below that is used on the main process.
|
|
||||||
self._pushers_id_gen = StreamIdGenerator(
|
self._pushers_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn,
|
db_conn=db_conn,
|
||||||
hs.get_replication_notifier(),
|
db=database,
|
||||||
"pushers",
|
notifier=hs.get_replication_notifier(),
|
||||||
"id",
|
stream_name="pushers",
|
||||||
extra_tables=[("deleted_pushers", "stream_id")],
|
instance_name=self._instance_name,
|
||||||
is_writer=hs.config.worker.worker_app is None,
|
tables=[
|
||||||
|
("pushers", "instance_name", "id"),
|
||||||
|
("deleted_pushers", "instance_name", "stream_id"),
|
||||||
|
],
|
||||||
|
sequence_name="pushers_sequence",
|
||||||
|
writers=["master"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db_pool.updates.register_background_update_handler(
|
self.db_pool.updates.register_background_update_handler(
|
||||||
|
@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
|
||||||
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
# Because we have write access, this will be a StreamIdGenerator
|
# Because we have write access, this will be a StreamIdGenerator
|
||||||
# (see PusherWorkerStore.__init__)
|
# (see PusherWorkerStore.__init__)
|
||||||
_pushers_id_gen: AbstractStreamIdGenerator
|
_pushers_id_gen: MultiWriterIdGenerator
|
||||||
|
|
||||||
async def add_pusher(
|
async def add_pusher(
|
||||||
self,
|
self,
|
||||||
|
@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
"last_stream_ordering": last_stream_ordering,
|
"last_stream_ordering": last_stream_ordering,
|
||||||
"profile_tag": profile_tag,
|
"profile_tag": profile_tag,
|
||||||
"id": stream_id,
|
"id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"enabled": enabled,
|
"enabled": enabled,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
# XXX(quenting): We're only really persisting the access token ID
|
# XXX(quenting): We're only really persisting the access token ID
|
||||||
|
@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
table="deleted_pushers",
|
table="deleted_pushers",
|
||||||
values={
|
values={
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"pushkey": pushkey,
|
"pushkey": pushkey,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||||
self.db_pool.simple_insert_many_txn(
|
self.db_pool.simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="deleted_pushers",
|
table="deleted_pushers",
|
||||||
keys=("stream_id", "app_id", "pushkey", "user_id"),
|
keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
|
||||||
values=[
|
values=[
|
||||||
(stream_id, pusher.app_id, pusher.pushkey, user_id)
|
(
|
||||||
|
stream_id,
|
||||||
|
self._instance_name,
|
||||||
|
pusher.app_id,
|
||||||
|
pusher.pushkey,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
for stream_id, pusher in zip(stream_ids, pushers)
|
for stream_id, pusher in zip(stream_ids, pushers)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2024 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
-- Add `instance_name` columns to stream tables to allow them to be used with
|
||||||
|
-- `MultiWriterIdGenerator`
|
||||||
|
ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
|
||||||
|
|
||||||
|
ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
|
||||||
|
|
||||||
|
ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
|
||||||
|
|
||||||
|
ALTER TABLE pushers ADD COLUMN instance_name TEXT;
|
||||||
|
ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;
|
|
@ -0,0 +1,54 @@
|
||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2024 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
-- Add squences for stream tables to allow them to be used with
|
||||||
|
-- `MultiWriterIdGenerator`
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
|
||||||
|
|
||||||
|
-- We need to take the max across all the device lists tables as they share the
|
||||||
|
-- ID generator
|
||||||
|
SELECT setval('device_lists_sequence', (
|
||||||
|
SELECT GREATEST(
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
|
||||||
|
)
|
||||||
|
));
|
||||||
|
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
|
||||||
|
|
||||||
|
SELECT setval('e2e_cross_signing_keys_sequence', (
|
||||||
|
SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
|
||||||
|
));
|
||||||
|
|
||||||
|
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
|
||||||
|
|
||||||
|
SELECT setval('push_rules_stream_sequence', (
|
||||||
|
SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
|
||||||
|
));
|
||||||
|
|
||||||
|
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
|
||||||
|
|
||||||
|
-- We need to take the max across all the pusher tables as they share the
|
||||||
|
-- ID generator
|
||||||
|
SELECT setval('pushers_sequence', (
|
||||||
|
SELECT GREATEST(
|
||||||
|
(SELECT COALESCE(MAX(id), 1) FROM pushers),
|
||||||
|
(SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
|
||||||
|
)
|
||||||
|
));
|
|
@ -23,15 +23,12 @@ import abc
|
||||||
import heapq
|
import heapq
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import OrderedDict
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
@ -179,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class StreamIdGenerator(AbstractStreamIdGenerator):
|
|
||||||
"""Generates and tracks stream IDs for a stream with a single writer.
|
|
||||||
|
|
||||||
This class must only be used when the current Synapse process is the sole
|
|
||||||
writer for a stream.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_conn(connection): A database connection to use to fetch the
|
|
||||||
initial value of the generator from.
|
|
||||||
table(str): A database table to read the initial value of the id
|
|
||||||
generator from.
|
|
||||||
column(str): The column of the database table to read the initial
|
|
||||||
value from the id generator from.
|
|
||||||
extra_tables(list): List of pairs of database tables and columns to
|
|
||||||
use to source the initial value of the generator from. The value
|
|
||||||
with the largest magnitude is used.
|
|
||||||
step(int): which direction the stream ids grow in. +1 to grow
|
|
||||||
upwards, -1 to grow downwards.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
async with stream_id_gen.get_next() as stream_id:
|
|
||||||
# ... persist event ...
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
db_conn: LoggingDatabaseConnection,
|
|
||||||
notifier: "ReplicationNotifier",
|
|
||||||
table: str,
|
|
||||||
column: str,
|
|
||||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
|
||||||
step: int = 1,
|
|
||||||
is_writer: bool = True,
|
|
||||||
) -> None:
|
|
||||||
assert step != 0
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._step: int = step
|
|
||||||
self._current: int = _load_current_id(db_conn, table, column, step)
|
|
||||||
self._is_writer = is_writer
|
|
||||||
for table, column in extra_tables:
|
|
||||||
self._current = (max if step > 0 else min)(
|
|
||||||
self._current, _load_current_id(db_conn, table, column, step)
|
|
||||||
)
|
|
||||||
|
|
||||||
# We use this as an ordered set, as we want to efficiently append items,
|
|
||||||
# remove items and get the first item. Since we insert IDs in order, the
|
|
||||||
# insertion ordering will ensure its in the correct ordering.
|
|
||||||
#
|
|
||||||
# The key and values are the same, but we never look at the values.
|
|
||||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
|
||||||
|
|
||||||
self._notifier = notifier
|
|
||||||
|
|
||||||
def advance(self, instance_name: str, new_id: int) -> None:
|
|
||||||
# Advance should never be called on a writer instance, only over replication
|
|
||||||
if self._is_writer:
|
|
||||||
raise Exception("Replication is not supported by writer StreamIdGenerator")
|
|
||||||
|
|
||||||
self._current = (max if self._step > 0 else min)(self._current, new_id)
|
|
||||||
|
|
||||||
def get_next(self) -> AsyncContextManager[int]:
|
|
||||||
with self._lock:
|
|
||||||
self._current += self._step
|
|
||||||
next_id = self._current
|
|
||||||
|
|
||||||
self._unfinished_ids[next_id] = next_id
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def manager() -> Generator[int, None, None]:
|
|
||||||
try:
|
|
||||||
yield next_id
|
|
||||||
finally:
|
|
||||||
with self._lock:
|
|
||||||
self._unfinished_ids.pop(next_id)
|
|
||||||
|
|
||||||
self._notifier.notify_replication()
|
|
||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
|
||||||
|
|
||||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
|
||||||
with self._lock:
|
|
||||||
next_ids = range(
|
|
||||||
self._current + self._step,
|
|
||||||
self._current + self._step * (n + 1),
|
|
||||||
self._step,
|
|
||||||
)
|
|
||||||
self._current += n * self._step
|
|
||||||
|
|
||||||
for next_id in next_ids:
|
|
||||||
self._unfinished_ids[next_id] = next_id
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def manager() -> Generator[Sequence[int], None, None]:
|
|
||||||
try:
|
|
||||||
yield next_ids
|
|
||||||
finally:
|
|
||||||
with self._lock:
|
|
||||||
for next_id in next_ids:
|
|
||||||
self._unfinished_ids.pop(next_id)
|
|
||||||
|
|
||||||
self._notifier.notify_replication()
|
|
||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
|
||||||
|
|
||||||
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
|
||||||
"""
|
|
||||||
Retrieve the next stream ID from within a database transaction.
|
|
||||||
|
|
||||||
Clean-up functions will be called when the transaction finishes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
txn: The database transaction object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The next stream ID.
|
|
||||||
"""
|
|
||||||
if not self._is_writer:
|
|
||||||
raise Exception("Tried to allocate stream ID on non-writer")
|
|
||||||
|
|
||||||
# Get the next stream ID.
|
|
||||||
with self._lock:
|
|
||||||
self._current += self._step
|
|
||||||
next_id = self._current
|
|
||||||
|
|
||||||
self._unfinished_ids[next_id] = next_id
|
|
||||||
|
|
||||||
def clear_unfinished_id(id_to_clear: int) -> None:
|
|
||||||
"""A function to mark processing this ID as finished"""
|
|
||||||
with self._lock:
|
|
||||||
self._unfinished_ids.pop(id_to_clear)
|
|
||||||
|
|
||||||
# Mark this ID as finished once the database transaction itself finishes.
|
|
||||||
txn.call_after(clear_unfinished_id, next_id)
|
|
||||||
txn.call_on_exception(clear_unfinished_id, next_id)
|
|
||||||
|
|
||||||
# Return the new ID.
|
|
||||||
return next_id
|
|
||||||
|
|
||||||
def get_current_token(self) -> int:
|
|
||||||
if not self._is_writer:
|
|
||||||
return self._current
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if self._unfinished_ids:
|
|
||||||
return next(iter(self._unfinished_ids)) - self._step
|
|
||||||
|
|
||||||
return self._current
|
|
||||||
|
|
||||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
|
||||||
return self.get_current_token()
|
|
||||||
|
|
||||||
def get_minimal_local_current_token(self) -> int:
|
|
||||||
return self.get_current_token()
|
|
||||||
|
|
||||||
|
|
||||||
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||||
"""Generates and tracks stream IDs for a stream with multiple writers.
|
"""Generates and tracks stream IDs for a stream with multiple writers.
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.engines import IncorrectDatabaseSetup
|
from synapse.storage.engines import IncorrectDatabaseSetup
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
from synapse.storage.util.sequence import (
|
from synapse.storage.util.sequence import (
|
||||||
LocalSequenceGenerator,
|
LocalSequenceGenerator,
|
||||||
PostgresSequenceGenerator,
|
PostgresSequenceGenerator,
|
||||||
|
@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase
|
||||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
|
||||||
|
|
||||||
class StreamIdGeneratorTestCase(HomeserverTestCase):
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
|
||||||
self.store = hs.get_datastores().main
|
|
||||||
self.db_pool: DatabasePool = self.store.db_pool
|
|
||||||
|
|
||||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
|
||||||
|
|
||||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
|
||||||
txn.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE foobar (
|
|
||||||
stream_id BIGINT NOT NULL,
|
|
||||||
data TEXT
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
|
|
||||||
|
|
||||||
def _create_id_generator(self) -> StreamIdGenerator:
|
|
||||||
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
|
|
||||||
return StreamIdGenerator(
|
|
||||||
db_conn=conn,
|
|
||||||
notifier=self.hs.get_replication_notifier(),
|
|
||||||
table="foobar",
|
|
||||||
column="stream_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
|
||||||
|
|
||||||
def test_initial_value(self) -> None:
|
|
||||||
"""Check that we read the current token from the DB."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
def test_single_gen_next(self) -> None:
|
|
||||||
"""Check that we correctly increment the current token from the DB."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
async with id_gen.get_next() as next_id:
|
|
||||||
# We haven't persisted `next_id` yet; current token is still 123
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
# But we did learn what the next value is
|
|
||||||
self.assertEqual(next_id, 124)
|
|
||||||
|
|
||||||
# Once the context manager closes we assume that the `next_id` has been
|
|
||||||
# written to the DB.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 124)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
def test_multiple_gen_nexts(self) -> None:
|
|
||||||
"""Check that we handle overlapping calls to gen_next sensibly."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
ctx1 = id_gen.get_next()
|
|
||||||
ctx2 = id_gen.get_next()
|
|
||||||
ctx3 = id_gen.get_next()
|
|
||||||
|
|
||||||
# Request three new stream IDs.
|
|
||||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
|
||||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
|
||||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
|
||||||
|
|
||||||
# None are persisted: current token unchanged.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Persist each in turn.
|
|
||||||
await ctx1.__aexit__(None, None, None)
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 124)
|
|
||||||
await ctx2.__aexit__(None, None, None)
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 125)
|
|
||||||
await ctx3.__aexit__(None, None, None)
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 126)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
|
|
||||||
"""Check that we handle overlapping calls to gen_next, even when their IDs
|
|
||||||
created and persisted in different orders."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
ctx1 = id_gen.get_next()
|
|
||||||
ctx2 = id_gen.get_next()
|
|
||||||
ctx3 = id_gen.get_next()
|
|
||||||
|
|
||||||
# Request three new stream IDs.
|
|
||||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
|
||||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
|
||||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
|
||||||
|
|
||||||
# None are persisted: current token unchanged.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Persist them in a different order, starting with 126 from ctx3.
|
|
||||||
await ctx3.__aexit__(None, None, None)
|
|
||||||
# We haven't persisted 124 from ctx1 yet---current token is still 123.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Now persist 124 from ctx1.
|
|
||||||
await ctx1.__aexit__(None, None, None)
|
|
||||||
# Current token is then 124, waiting for 125 to be persisted.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 124)
|
|
||||||
|
|
||||||
# Finally persist 125 from ctx2.
|
|
||||||
await ctx2.__aexit__(None, None, None)
|
|
||||||
# Current token is then 126 (skipping over 125).
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 126)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
|
|
||||||
"""Check that we handle overlapping calls to gen_next."""
|
|
||||||
id_gen = self._create_id_generator()
|
|
||||||
|
|
||||||
async def test_gen_next() -> None:
|
|
||||||
ctx1 = id_gen.get_next()
|
|
||||||
ctx2 = id_gen.get_next()
|
|
||||||
ctx3 = id_gen.get_next()
|
|
||||||
|
|
||||||
# Request two new stream IDs.
|
|
||||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
|
||||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
|
||||||
|
|
||||||
# Persist ctx2 first.
|
|
||||||
await ctx2.__aexit__(None, None, None)
|
|
||||||
# Still waiting on ctx1's ID to be persisted.
|
|
||||||
self.assertEqual(id_gen.get_current_token(), 123)
|
|
||||||
|
|
||||||
# Now request a third stream ID. It should be 126 (the smallest ID that
|
|
||||||
# we've not yet handed out.)
|
|
||||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
|
||||||
|
|
||||||
self.get_success(test_gen_next())
|
|
||||||
|
|
||||||
|
|
||||||
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
|
Loading…
Reference in a new issue