
Allow running sendToDevice on workers (#9044)

Erik Johnston authored 2021-01-07 20:19:26 +00:00, committed by GitHub
parent 5e99a94502
commit b530eaa262
11 changed files with 231 additions and 105 deletions

changelog.d/9044.feature (new file)

@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
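
At a glance, the feature lets a deployment point the new to_device stream at a single worker. A sketch of what that configuration amounts to, written as the Python mapping WorkerConfig consumes; the stream_writers key follows Synapse's existing stream-writer conventions, and the worker name worker1 is a placeholder, not part of this commit:

    # Hypothetical homeserver.yaml fragment, shown as the parsed Python mapping.
    config = {
        # Exactly one instance may write to-device messages (enforced in
        # synapse/config/workers.py below).
        "stream_writers": {
            "to_device": ["worker1"],  # assumed worker name
        },
        # Any non-master writer must also appear in instance_map so other
        # processes know how to reach it.
        "instance_map": {
            "worker1": {"host": "localhost", "port": 9091},
        },
    }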

scripts/synapse_port_db

@@ -629,6 +629,7 @@ class Porter(object):
         await self._setup_state_group_id_seq()
         await self._setup_user_id_seq()
         await self._setup_events_stream_seqs()
+        await self._setup_device_inbox_seq()

         # Step 3. Get tables.
         self.progress.set_state("Fetching tables")
@@ -911,6 +912,32 @@ class Porter(object):
             "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
         )

+    async def _setup_device_inbox_seq(self):
+        """Set the device inbox sequence to the correct value.
+        """
+        curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_inbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_federation_outbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        next_id = max(curr_local_id, curr_federation_id) + 1
+
+        def r(txn):
+            txn.execute(
+                "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
+            )
+
+        return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
+

 ##############################################
 # The following is simply UI stuff
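
Because device_inbox and device_federation_outbox share one ID generator, the port script must restart the new Postgres sequence past the larger of the two tables' maxima. A standalone sketch of the same fix-up, assuming a raw psycopg2 connection rather than Synapse's DB pool:

    import psycopg2

    def align_device_inbox_sequence(conn):
        """Restart device_inbox_sequence just past the largest stream_id in use."""
        with conn.cursor() as txn:
            txn.execute("SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox")
            local_max = txn.fetchone()[0]

            txn.execute(
                "SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox"
            )
            federation_max = txn.fetchone()[0]

            # psycopg2 interpolates parameters client-side, so this works for DDL.
            txn.execute(
                "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s",
                (max(local_max, federation_max) + 1,),
            )
        conn.commit()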

synapse/app/generic_worker.py

@@ -108,6 +108,7 @@ from synapse.rest.client.v2_alpha.account_data import (
 )
 from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource

@@ -520,6 +521,8 @@ class GenericWorkerServer(HomeServer):
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)

+                    SendToDeviceRestServlet(self).register(resource)
+
                     user_directory.register_servlets(self, resource)

                     # If presence is disabled, use the stub servlet that does
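
Registering SendToDeviceRestServlet on the generic worker means a client's sendToDevice request can be served by that worker directly. For reference, a sketch of the request the servlet handles, per the Matrix client-server API; the helper is illustrative, with the homeserver URL, token, and message body as placeholders:

    import json
    import urllib.request

    def send_to_device(homeserver, access_token, event_type, txn_id, messages):
        """PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}.

        `messages` maps user_id -> device_id -> message content.
        """
        url = "%s/_matrix/client/r0/sendToDevice/%s/%s" % (
            homeserver, event_type, txn_id,
        )
        req = urllib.request.Request(
            url,
            data=json.dumps({"messages": messages}).encode("utf-8"),
            headers={
                "Authorization": "Bearer %s" % access_token,
                "Content-Type": "application/json",
            },
            method="PUT",
        )
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())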

synapse/config/workers.py

@@ -53,6 +53,9 @@ class WriterLocations:
         default=["master"], type=List[str], converter=_instance_to_list_converter
     )
     typing = attr.ib(default="master", type=str)
+    to_device = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )


 class WorkerConfig(Config):

@@ -124,7 +127,7 @@ class WorkerConfig(Config):
         # Check that the configured writers for events and typing also appear in
         # `instance_map`.
-        for stream in ("events", "typing"):
+        for stream in ("events", "typing", "to_device"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:

@@ -133,6 +136,11 @@ class WorkerConfig(Config):
                         % (instance, stream)
                     )

+        if len(self.writers.to_device) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `to_device` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)

         # Whether this worker should run background tasks or not.
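
For reference, the _instance_to_list_converter used above lets the option accept either a single instance name or a list of names; a sketch of its behavior:

    from typing import List, Union

    def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
        """Convert an instance name, or a list of names, into a list of names."""
        if isinstance(obj, str):
            return [obj]
        return obj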

synapse/handlers/devicemessage.py

@@ -45,11 +45,25 @@ class DeviceMessageHandler:
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
-        self.federation = hs.get_federation_sender()

-        hs.get_federation_registry().register_edu_handler(
-            "m.direct_to_device", self.on_direct_to_device_edu
-        )
+        # We only need to poke the federation sender explicitly if it's on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the to-device replication stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the to-device EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.to_device:
+            hs.get_federation_registry().register_edu_handler(
+                "m.direct_to_device", self.on_direct_to_device_edu
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.direct_to_device", hs.config.worker.writers.to_device,
+            )

         # The handler to call when we think a user's device list might be out of
         # sync. We do all device list resyncing on the master instance, so if

@@ -204,7 +218,8 @@ class DeviceMessageHandler:
         )

         log_kv({"remote_messages": remote_messages})
-        for destination in remote_messages.keys():
-            # Enqueue a new federation transaction to send the new
-            # device messages to each remote destination.
-            self.federation.send_device_messages(destination)
+        if self.federation_sender:
+            for destination in remote_messages.keys():
+                # Enqueue a new federation transaction to send the new
+                # device messages to each remote destination.
+                self.federation_sender.send_device_messages(destination)
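
The two registry calls implement a handle-or-route split: the configured writer handles m.direct_to_device EDUs itself, and every other instance forwards them. An illustrative reduction of that pattern (a toy class, not Synapse's real FederationRegistry beyond the two method names shown in the diff):

    class EduDispatchSketch:
        """Toy dispatcher: handle an EDU type locally, or forward it to an
        instance configured to write that stream."""

        def __init__(self, forward):
            self._local_handlers = {}   # edu_type -> async handler
            self._edu_instances = {}    # edu_type -> list of instance names
            self._forward = forward     # async fn(instance, edu_type, origin, content)

        def register_edu_handler(self, edu_type, handler):
            self._local_handlers[edu_type] = handler

        def register_instances_for_edu(self, edu_type, instance_names):
            self._edu_instances[edu_type] = instance_names

        async def on_edu(self, edu_type, origin, content):
            handler = self._local_handlers.get(edu_type)
            if handler is not None:
                return await handler(origin, content)

            instances = self._edu_instances.get(edu_type)
            if instances:
                # In Synapse this forwarding happens over a replication HTTP call.
                return await self._forward(instances[0], edu_type, origin, content)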

synapse/replication/slave/storage/deviceinbox.py

@@ -14,38 +14,8 @@
 # limitations under the License.

 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache


 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass

synapse/replication/tcp/handler.py

@@ -56,6 +56,7 @@ from synapse.replication.tcp.streams import (
     EventsStream,
     FederationStream,
     Stream,
+    ToDeviceStream,
     TypingStream,
 )

@@ -115,6 +116,14 @@ class ReplicationCommandHandler:
                 continue

+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to-device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
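
The gating mirrors the existing typing stream: an instance only advertises a stream as a replication source if it is that stream's configured writer. A minimal reduction of the idea, with assumed names rather than Synapse's real classes:

    def pick_streams_to_replicate(all_streams, instance_name, writers_config):
        """Return the streams this instance should act as a source for."""
        sources = []
        for stream in all_streams:
            # writers_config maps stream name -> list of writer instance names,
            # defaulting to the main "master" process.
            writers = writers_config.get(stream.NAME, ["master"])
            if instance_name in writers:
                sources.append(stream)
        return sources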

synapse/storage/databases/main/__init__.py

@@ -127,9 +127,6 @@ class DataStore(
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
-        self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_inbox", "stream_id"
-        )
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )

@@ -189,36 +186,6 @@ class DataStore(
             prefilled_cache=presence_cache_prefill,
         )

-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_inbox",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            min_device_inbox_id,
-            prefilled_cache=device_inbox_prefill,
-        )
-
-        # The federation outbox and the local device inbox uses the same
-        # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_federation_outbox",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            min_device_outbox_id,
-            prefilled_cache=device_outbox_prefill,
-        )
-
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max

synapse/storage/databases/main/deviceinbox.py

@@ -17,10 +17,14 @@ import logging
 from typing import List, Tuple

 from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache

 logger = logging.getLogger(__name__)
@@ -29,6 +33,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)

+        self._instance_name = hs.get_instance_name()
+
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
         self._last_device_delete_cache = ExpiringCache(
@@ -38,6 +44,73 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             expiry_ms=30 * 60 * 1000,
         )

+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_device = (
+                self._instance_name in hs.config.worker.writers.to_device
+            )
+
+            self._device_inbox_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="to_device",
+                instance_name=self._instance_name,
+                table="device_inbox",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="device_inbox_sequence",
+                writers=hs.config.worker.writers.to_device,
+            )
+        else:
+            self._can_write_to_device = True
+            self._device_inbox_id_gen = StreamIdGenerator(
+                db_conn, "device_inbox", "stream_id"
+            )
+
+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache",
+            min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+
+        # The federation outbox and the local device inbox use the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+            for row in rows:
+                if row.entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
@@ -290,38 +363,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )

-
-class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            "device_inbox_stream_index",
-            index_name="device_inbox_stream_id_user_id",
-            table="device_inbox",
-            columns=["stream_id", "user_id"],
-        )
-
-        self.db_pool.updates.register_background_update_handler(
-            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
-        )
-
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
-            txn = conn.cursor()
-            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
-            txn.close()
-
-        await self.db_pool.runWithConnection(reindex_txn)
-
-        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
-        return 1
-
-
-class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     @trace
     async def add_messages_to_device_inbox(
         self,
@@ -340,6 +381,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
             The new stream_id.
         """

+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(

@@ -358,6 +401,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
                         "stream_id": stream_id,
                         "queued_ts": now_ms,
                         "messages_json": json_encoder.encode(edu),
+                        "instance_name": self._instance_name,
                     }
                     for destination, edu in remote_messages_by_destination.items()
                 ],
@@ -380,6 +424,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     async def add_messages_from_remote_to_device_inbox(
         self, origin: str, message_id: str, local_messages_by_user_then_device: dict
     ) -> int:
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
@@ -428,6 +474,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     def _add_messages_to_local_device_inbox_txn(
         self, txn, stream_id, messages_by_user_then_device
     ):
+        assert self._can_write_to_device
+
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
             messages_json_for_user = {}
@@ -481,8 +529,43 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
                     "device_id": device_id,
                     "stream_id": stream_id,
                     "message_json": message_json,
+                    "instance_name": self._instance_name,
                 }
                 for user_id, messages_by_device in local_by_user_then_device.items()
                 for device_id, message_json in messages_by_device.items()
             ],
         )
+
+
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
+        )
+
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+            txn.close()
+
+        await self.db_pool.runWithConnection(reindex_txn)
+
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+    pass
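
The Postgres/SQLite split above is the crux of the change: on Postgres, stream IDs come from the shared device_inbox_sequence, so any configured writer can allocate them without collisions; on SQLite there is only ever one process, so the old in-memory StreamIdGenerator suffices. A toy sketch of the Postgres primitive, assuming a psycopg2-style cursor; Synapse's MultiWriterIdGenerator layers persisted per-instance positions on top of this:

    def next_to_device_stream_id(txn):
        """Allocate the next stream_id from the shared Postgres sequence.

        Every writer instance draws from the same server-side counter, so two
        workers can never hand out the same stream_id, even concurrently.
        """
        txn.execute("SELECT nextval('device_inbox_sequence')")
        return txn.fetchone()[0]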

New schema delta (SQL): add instance_name columns to the device inbox tables

@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE device_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT;

New schema delta (SQL, Postgres only): create device_inbox_sequence

@@ -0,0 +1,25 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence;
+
+-- We need to take the max across both device_inbox and device_federation_outbox
+-- tables as they share the ID generator
+SELECT setval('device_inbox_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)
+    )
+));
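
setval leaves the sequence positioned at the given value, so the next nextval call returns a number strictly greater than any stream_id already in either table. A quick sanity check one might run after applying the delta, assuming a psycopg2 connection (not part of the commit):

    def check_device_inbox_sequence(conn):
        """Assert the sequence has been advanced past both tables' maxima."""
        with conn.cursor() as cur:
            cur.execute(
                "SELECT GREATEST("
                " (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),"
                " (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)"
                ")"
            )
            table_max = cur.fetchone()[0]

            cur.execute("SELECT last_value FROM device_inbox_sequence")
            seq_value = cur.fetchone()[0]

            assert seq_value >= table_max, (seq_value, table_max)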