to-device messages are tentatively working

This commit is contained in:
Andrew Morgan 2021-10-29 18:34:34 +01:00
parent 95a49727a6
commit eeccdf0e98
6 changed files with 291 additions and 33 deletions

View file

@ -183,7 +183,7 @@ class ApplicationServicesHandler:
self,
stream_key: str,
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
users: Collection[Union[str, UserID]],
) -> None:
"""
This is called by the notifier in the background when
@ -200,6 +200,8 @@ class ApplicationServicesHandler:
Appservices will only receive ephemeral events that fall within their
registered user and room namespaces.
TODO: Update this bit
Any other value for `stream_key` will cause this function to return early.
new_token: The latest stream token.
@ -208,21 +210,38 @@ class ApplicationServicesHandler:
if not self.notify_appservices:
return
if stream_key not in ("typing_key", "receipt_key", "presence_key"):
return
if stream_key in ("typing_key", "receipt_key", "presence_key"):
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
# Note that whether these events are actually relevant to these appservices
# is decided later on.
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
]
if not services:
# Bail out early if none of the target appservices have explicitly registered
# to receive these ephemeral events.
return
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
]
if not services:
elif stream_key == "to_device_key":
# Appservices do not need to register explicit support for receiving device list
# updates.
#
# Note that whether these events are actually relevant to these appservices is
# decided later on.
services = self.store.get_app_services()
else:
# This stream_key is not supported.
return
# We only start a new background process if necessary rather than
# optimistically (to cut down on overhead).
self._notify_interested_services_ephemeral(
services, stream_key, new_token, users or []
services, stream_key, new_token, users
)
@wrap_as_background_process("notify_interested_services_ephemeral")
@ -233,7 +252,7 @@ class ApplicationServicesHandler:
new_token: Optional[int],
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
logger.debug("Checking interested services for %s" % stream_key)
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token
@ -256,6 +275,8 @@ class ApplicationServicesHandler:
# Persist the latest handled stream token for this appservice
# TODO: We seem to update the stream token for each appservice,
# even if sending the ephemeral events to the appservice failed.
# This is expected for typing, receipt and presence, but will need
# to be handled for device* streams.
await self.store.set_type_stream_id_for_appservice(
service, "read_receipt", new_token
)
@ -270,6 +291,104 @@ class ApplicationServicesHandler:
service, "presence", new_token
)
elif stream_key == "to_device_key" and new_token is not None:
events = await self._handle_to_device(service, new_token, users)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
# Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice(
service, "to_device", new_token
)
async def _handle_to_device(
self,
service: ApplicationService,
new_token: int,
users: Collection[Union[str, UserID]],
) -> List[JsonDict]:
"""
Given an application service, determine which events it should receive
from those between the last-recorded typing event stream token for this
appservice and the given stream token.
Args:
service: The application service to check for which events it should receive.
new_token: The latest to-device event stream token.
users: The users that should receive new to-device messages.
Returns:
A list of JSON dictionaries containing data derived from the typing events that
should be sent to the given application service.
"""
# Get the stream token that this application service has processed up until
# TODO: Is 'users' always going to be one user here? Sometimes it's the sender!
# TODO: DB migration to add a column for to_device to application_services_state table
from_key = await self.store.get_type_stream_id_for_appservice(
service, "to_device"
)
# Filter out users that this appservice is not interested in
users_appservice_is_interested_in: List[str] = []
for user in users:
if isinstance(user, UserID):
user = user.to_string()
if service.is_interested_in_user(user):
users_appservice_is_interested_in.append(user)
if not users_appservice_is_interested_in:
# Return early if the AS was not interested in any of these users
return []
# Retrieve the to-device messages for each user
(
recipient_user_id_device_id_to_messages,
max_stream_token,
) = await self.store.get_new_messages(
users_appservice_is_interested_in, from_key, new_token, limit=100
)
logger.info(
"*** Users: %s, from: %s, to: %s",
users_appservice_is_interested_in,
from_key,
new_token,
)
logger.info(
"*** Got to-device message: %s", recipient_user_id_device_id_to_messages
)
# TODO: Keep pulling out if max_stream_token != new_token?
# According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
# to the event JSON so that the application service will know which user/device
# combination this messages was intended for.
#
# So we mangle this dict into a flat list of to-device messages with the relevant
# user ID and device ID embedded inside each message dict.
message_payload: List[JsonDict] = []
for (
user_id,
device_id,
), messages in recipient_user_id_device_id_to_messages.items():
for message_json in messages:
# Remove 'message_id' from the to-device message, as it's an internal ID
message_json.pop("message_id", None)
message_payload.append(
{
"to_user_id": user_id,
"to_device_id": device_id,
**message_json,
}
)
logger.info("*** Ended up with messages: %s", message_payload)
return message_payload
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:

View file

@ -89,6 +89,13 @@ class DeviceMessageHandler:
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
"""
Handle receiving to-device messages from remote homeservers.
Args:
origin: The remote homeserver.
content: The JSON dictionary containing the to-device messages.
"""
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
@ -135,12 +142,16 @@ class DeviceMessageHandler:
message_type, sender_user_id, by_device
)
stream_id = await self.store.add_messages_from_remote_to_device_inbox(
# Add messages to the database.
# Retrieve the stream token of the last-processed to-device message.
max_stream_token = await self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream token.
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
"to_device_key", max_stream_token, users=local_messages.keys()
)
async def _check_for_unknown_devices(
@ -195,6 +206,14 @@ class DeviceMessageHandler:
message_type: str,
messages: Dict[str, Dict[str, JsonDict]],
) -> None:
"""
Handle a request from a user to send to-device message(s).
Args:
requester: The user that is sending the to-device messages.
message_type: The type of to-device messages that are being sent.
messages: A dictionary containing recipients mapped to messages intended for them.
"""
sender_user_id = requester.user.to_string()
message_id = random_string(16)
@ -257,12 +276,16 @@ class DeviceMessageHandler:
"org.matrix.opentracing_context": json_encoder.encode(context),
}
stream_id = await self.store.add_messages_to_device_inbox(
# Add messages to the database.
# Retrieve the stream token of the last-processed to-device message.
max_stream_token = await self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
# Notify listeners that there are new to-device messages to process,
# handing them the latest stream token.
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
"to_device_key", max_stream_token, users=local_messages.keys()
)
if self.federation_sender:

View file

@ -378,7 +378,7 @@ class Notifier:
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
users: Collection[Union[str, UserID]],
) -> None:
"""Notify application services of ephemeral event activity.
@ -392,7 +392,7 @@ class Notifier:
if isinstance(new_token, int):
stream_token = new_token
self.appservice_handler.notify_interested_services_ephemeral(
stream_key, stream_token, users or []
stream_key, stream_token, users
)
except Exception:
logger.exception("Error notifying application services of event")

View file

@ -387,7 +387,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
if type not in ("read_receipt", "presence"):
if type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: Optional[int]
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
if type not in ("read_receipt", "presence"):
if stream_type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
)
def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
% stream_id_type,

View file

@ -13,15 +13,16 @@
# limitations under the License.
import logging
from typing import List, Optional, Tuple
from typing import Collection, Dict, List, Optional, Tuple
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -112,31 +113,125 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
async def get_new_messages(
self,
user_ids: Collection[str],
from_stream_token: int,
to_stream_token: int,
limit: int = 100,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve to-device messages for a given set of user IDs.
Only to-device messages with stream tokens between the given boundaries
(from < X <= to) are returned.
Note that multiple messages can have the same stream token. Stream tokens are
unique to *messages*, but there might be multiple recipients of a message, and
thus multiple entries in the device_inbox table with the same stream token.
Args:
user_ids: The users to retrieve to-device messages for.
from_stream_token: The lower boundary of stream token to filter with (exclusive).
to_stream_token: The upper boundary of stream token to filter with (inclusive).
limit: The maximum number of to-device messages to return.
Returns:
A tuple containing the following:
* A list of to-device messages
"""
# Bail out if none of these users have any messages
for user_id in user_ids:
if self._device_inbox_stream_cache.has_entity_changed(
user_id, from_stream_token
):
break
else:
logger.info("*** Bailing out")
return {}, to_stream_token
def get_new_messages_txn(txn):
# Build a query to select messages from any of the given users that are between
# the given stream token bounds
sql = "SELECT stream_id, user_id, device_id, message_json FROM device_inbox"
# Scope to only the given users. We need to use this method as doing so is
# different across database engines.
many_clause_sql, many_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
sql += (
" WHERE %s"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
) % many_clause_sql
logger.info("*** %s\n\n%s", many_clause_sql, many_clause_args)
logger.info("*** %s", sql)
txn.execute(
sql, (*many_clause_args, from_stream_token, to_stream_token, limit)
)
# Create a dictionary of (user ID, device ID) -> list of messages that
# that device is meant to receive.
recipient_user_id_device_id_to_messages = {}
stream_pos = to_stream_token
total_messages_processed = 0
for row in txn:
# Record the last-processed stream position, to return later.
# Note that we process messages here in order of ascending stream token.
stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
recipient_user_id_device_id_to_messages.setdefault(
(recipient_user_id, recipient_device_id), []
).append(message_dict)
total_messages_processed += 1
# This is needed (REVIEW: I think) as you can have multiple rows for a
# single to-device message (due to multiple recipients).
if total_messages_processed < limit:
stream_pos = to_stream_token
return recipient_user_id_device_id_to_messages, stream_pos
return await self.db_pool.runInteraction(
"get_new_messages", get_new_messages_txn
)
async def get_new_messages_for_device(
self,
user_id: str,
device_id: Optional[str],
last_stream_id: int,
current_stream_id: int,
last_stream_token: int,
current_stream_token: int,
limit: int = 100,
) -> Tuple[List[dict], int]:
"""
Args:
user_id: The recipient user_id.
device_id: The recipient device_id.
last_stream_id: The last stream ID checked.
current_stream_id: The current position of the to device
last_stream_token: The last stream ID checked.
current_stream_token: The current position of the to device
message stream.
limit: The maximum number of messages to retrieve.
Returns:
A list of messages for the device and where in the stream the messages got to.
A tuple containing:
* A list of messages for the device
* The max stream token of these messages. There may be more to retrieve
if the given limit was reached.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
user_id, last_stream_token
)
if not has_changed:
return [], current_stream_id
return [], current_stream_token
def get_new_messages_for_device_txn(txn):
sql = (
@ -147,14 +242,17 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
sql,
(user_id, device_id, last_stream_token, current_stream_token, limit),
)
messages = []
stream_pos = current_stream_token
for row in txn:
stream_pos = row[0]
messages.append(db_to_json(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
stream_pos = current_stream_token
return messages, stream_pos
return await self.db_pool.runInteraction(
@ -369,7 +467,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
Args:
local_messages_by_user_and_device:
Dictionary of user_id to device_id to message.
Dictionary of recipient user_id to recipient device_id to message.
remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send.

View file

@ -0,0 +1,18 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to track what to_device stream token that this application
-- service has been caught up to.
ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;