Merge device list replication streams (#14833)

This commit is contained in:
Erik Johnston 2023-01-17 09:29:58 +00:00 committed by GitHub
parent db5145a31d
commit 2b084c5b71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 72 additions and 38 deletions

View file

@ -1 +1 @@
Merge tag and normal account data replication streams. Merge the two account data and the two device list replication streams.

1
changelog.d/14833.misc Normal file
View file

@ -0,0 +1 @@
Merge the two account data and the two device list replication streams.

View file

@ -92,12 +92,13 @@ process, for example:
## Changes to the account data replication streams ## Changes to the account data replication streams
Synapse has changed the format of the account data replication streams (between Synapse has changed the format of the account data and devices replication
workers). This is a forwards- and backwards-incompatible change: v1.75 workers streams (between workers). This is a forwards- and backwards-incompatible
cannot process account data replicated by v1.76 workers, and vice versa. change: v1.75 workers cannot process account data or device list updates replicated by v1.76 workers,
and vice versa.
Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data
replication will resume as normal. and device replication will resume as normal.
# Upgrading to v1.74.0 # Upgrading to v1.74.0

View file

@ -187,7 +187,7 @@ class ReplicationDataHandler:
elif stream_name == DeviceListsStream.NAME: elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set() all_room_ids: Set[str] = set()
for row in rows: for row in rows:
if row.entity.startswith("@"): if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity) room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids) all_room_ids.update(room_ids)
self.notifier.on_new_event( self.notifier.on_new_event(
@ -422,7 +422,11 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices # The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about # have changed, or remote servers that we need to tell about
# changes. # changes.
hosts = {row.entity for row in rows if not row.entity.startswith("@")} hosts = {
row.entity
for row in rows
if not row.entity.startswith("@") and not row.is_signature
}
for host in hosts: for host in hosts:
self.federation_sender.send_device_messages(host, immediate=False) self.federation_sender.send_device_messages(host, immediate=False)

View file

@ -37,7 +37,6 @@ from synapse.replication.tcp.streams._base import (
Stream, Stream,
ToDeviceStream, ToDeviceStream,
TypingStream, TypingStream,
UserSignatureStream,
) )
from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream from synapse.replication.tcp.streams.federation import FederationStream
@ -62,7 +61,6 @@ STREAMS_MAP = {
ToDeviceStream, ToDeviceStream,
FederationStream, FederationStream,
AccountDataStream, AccountDataStream,
UserSignatureStream,
UnPartialStatedRoomStream, UnPartialStatedRoomStream,
UnPartialStatedEventStream, UnPartialStatedEventStream,
) )
@ -82,7 +80,6 @@ __all__ = [
"DeviceListsStream", "DeviceListsStream",
"ToDeviceStream", "ToDeviceStream",
"AccountDataStream", "AccountDataStream",
"UserSignatureStream",
"UnPartialStatedRoomStream", "UnPartialStatedRoomStream",
"UnPartialStatedEventStream", "UnPartialStatedEventStream",
] ]

View file

@ -463,18 +463,67 @@ class DeviceListsStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow: class DeviceListsStreamRow:
entity: str entity: str
# Indicates that a user has signed their own device with their user-signing key
is_signature: bool
NAME = "device_lists" NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token), current_token_without_instance(self.store.get_device_stream_token),
store.get_all_device_list_changes_for_remotes, self._update_function,
) )
async def _update_function(
    self,
    instance_name: str,
    from_token: Token,
    current_token: Token,
    target_row_count: int,
) -> StreamUpdateResult:
    """Fetch device list and user signature changes as a single stream.

    Queries ``self.store`` twice with the same arguments, once for device
    list changes and once for user signature changes, and interleaves the
    two result sets by stream ID.

    Args:
        instance_name: passed through unchanged to both store queries.
        from_token: passed through unchanged to both store queries.
        current_token: passed through to both store queries, and used as
            the starting value for the returned upper-limit token.
        target_row_count: passed through unchanged to both store queries.

    Returns:
        A ``(updates, upper_limit_token, limited)`` triple. Each update is
        ``(stream_id, (entity, is_signature))`` where the flag is ``False``
        for rows from the device list query and ``True`` for rows from the
        user signature query. ``limited`` is truthy if either query
        reported that it was limited.
    """
    query_args = (instance_name, from_token, current_token, target_row_count)

    # Each store call yields a (rows, to_token, limited) triple.
    device_rows, device_reached, device_truncated = (
        await self.store.get_all_device_list_changes_for_remotes(*query_args)
    )
    signature_rows, signature_reached, signature_truncated = (
        await self.store.get_all_user_signature_changes_for_remotes(*query_args)
    )

    # If a query was cut short, only rows up to the token it reached can be
    # handed out; take the lowest such token across both queries.
    ceiling = current_token
    for truncated, reached in (
        (device_truncated, device_reached),
        (signature_truncated, signature_reached),
    ):
        if truncated:
            ceiling = min(ceiling, reached)

    # Tag every row with whether it came from the signature query, dropping
    # anything beyond the ceiling computed above.
    tagged_devices = [
        (stream_id, (entity, False))
        for stream_id, (entity,) in device_rows
        if stream_id <= ceiling
    ]
    tagged_signatures = [
        (stream_id, (entity, True))
        for stream_id, (entity,) in signature_rows
        if stream_id <= ceiling
    ]

    # NOTE(review): heapq.merge assumes each input is already ordered by
    # stream_id -- confirm the store queries return rows in ascending order.
    merged = list(
        heapq.merge(tagged_devices, tagged_signatures, key=lambda item: item[0])
    )

    return merged, ceiling, device_truncated or signature_truncated
class ToDeviceStream(Stream): class ToDeviceStream(Stream):
"""New to_device messages for a client""" """New to_device messages for a client"""
@ -583,22 +632,3 @@ class AccountDataStream(Stream):
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
) )
return updates, to_token, limited return updates, to_token, limited
class UserSignatureStream(Stream):
    """A user has signed their own device with their user-signing key"""

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class UserSignatureStreamRow:
        # presumably the ID of the user whose signatures changed -- verify
        # against the store query that feeds this stream
        user_id: str

    NAME = "user_signature"
    ROW_TYPE = UserSignatureStreamRow

    def __init__(self, hs: "HomeServer"):
        # The stream position and the update rows both come from the main
        # datastore; the position is the device stream token.
        main_store = hs.get_datastores().main

        local_instance = hs.get_instance_name()
        position_getter = current_token_without_instance(
            main_store.get_device_stream_token
        )
        fetch_updates = main_store.get_all_user_signature_changes_for_remotes

        super().__init__(local_instance, position_getter, fetch_updates)

View file

@ -38,7 +38,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver, whitelisted_homeserver,
) )
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -163,9 +163,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None: ) -> None:
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
self._invalidate_caches_for_devices(token, rows) self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position( def process_replication_position(
@ -173,14 +171,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None: ) -> None:
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token) self._device_list_id_gen.advance(instance_name, token)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token) super().process_replication_position(stream_name, instance_name, token)
def _invalidate_caches_for_devices( def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None: ) -> None:
for row in rows: for row in rows:
if row.is_signature:
self._user_signature_stream_cache.entity_has_changed(row.entity, token)
continue
# The entities are either user IDs (starting with '@') whose devices # The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about # have changed, or remote servers that we need to tell about
# changes. # changes.