0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-08 21:58:53 +02:00

Send to-device messages to application services (#11215)

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
This commit is contained in:
Andrew Morgan 2022-02-01 14:13:38 +00:00 committed by GitHub
parent b7282fe7d1
commit 64ec45fc1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 856 additions and 162 deletions

View file

@ -0,0 +1 @@
Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). Disabled by default.

View file

@ -351,11 +351,13 @@ class AppServiceTransaction:
id: int,
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
):
self.service = service
self.id = id
self.events = events
self.ephemeral = ephemeral
self.to_device_messages = to_device_messages
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@ -369,6 +371,7 @@ class AppServiceTransaction:
service=self.service,
events=self.events,
ephemeral=self.ephemeral,
to_device_messages=self.to_device_messages,
txn_id=self.id,
)

View file

@ -218,8 +218,23 @@ class ApplicationServiceApi(SimpleHttpClient):
service: "ApplicationService",
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
txn_id: Optional[int] = None,
) -> bool:
"""
Push data to an application service.
Args:
service: The application service to send to.
events: The persistent events to send.
ephemeral: The ephemeral events to send.
to_device_messages: The to-device messages to send.
txn_id: An unique ID to assign to this transaction. Application services should
deduplicate transactions received with identitical IDs.
Returns:
True if the task succeeded, False if it failed.
"""
if service.url is None:
return True
@ -237,13 +252,15 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
# Never send ephemeral events to appservices that do not support it
body: Dict[str, List[JsonDict]] = {"events": serialized_events}
if service.supports_ephemeral:
body = {
"events": serialized_events,
"de.sorunome.msc2409.ephemeral": ephemeral,
}
else:
body = {"events": serialized_events}
body.update(
{
# TODO: Update to stable prefixes once MSC2409 completes FCP merge.
"de.sorunome.msc2409.ephemeral": ephemeral,
"de.sorunome.msc2409.to_device": to_device_messages,
}
)
try:
await self.put_json(

View file

@ -48,7 +48,16 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Collection,
Dict,
List,
Optional,
Set,
)
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.appservice.api import ApplicationServiceApi
@ -71,6 +80,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
# Maximum number of ephemeral events to provide in an AS transaction.
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
# Maximum number of to-device messages to provide in an AS transaction.
MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100
class ApplicationServiceScheduler:
"""Public facing API for this module. Does the required DI to tie the
@ -97,15 +109,40 @@ class ApplicationServiceScheduler:
for service in services:
self.txn_ctrl.start_recoverer(service)
def submit_event_for_as(
self, service: ApplicationService, event: EventBase
def enqueue_for_appservice(
self,
appservice: ApplicationService,
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None,
) -> None:
self.queuer.enqueue_event(service, event)
"""
Enqueue some data to be sent off to an application service.
def submit_ephemeral_events_for_as(
self, service: ApplicationService, events: List[JsonDict]
) -> None:
self.queuer.enqueue_ephemeral(service, events)
Args:
appservice: The application service to create and send a transaction to.
events: The persistent room events to send.
ephemeral: The ephemeral events to send.
to_device_messages: The to-device messages to send. These differ from normal
to-device messages sent to clients, as they have 'to_device_id' and
'to_user_id' fields.
"""
# We purposefully allow this method to run with empty events/ephemeral
# collections, so that callers do not need to check iterable size themselves.
if not events and not ephemeral and not to_device_messages:
return
if events:
self.queuer.queued_events.setdefault(appservice.id, []).extend(events)
if ephemeral:
self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral)
if to_device_messages:
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
to_device_messages
)
# Kick off a new application service transaction
self.queuer.start_background_request(appservice)
class _ServiceQueuer:
@ -121,13 +158,15 @@ class _ServiceQueuer:
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]}
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
def _start_background_request(self, service: ApplicationService) -> None:
def start_background_request(self, service: ApplicationService) -> None:
# start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight:
return
@ -136,16 +175,6 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service
)
def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
self.queued_events.setdefault(service.id, []).append(event)
self._start_background_request(service)
def enqueue_ephemeral(
self, service: ApplicationService, events: List[JsonDict]
) -> None:
self.queued_ephemeral.setdefault(service.id, []).extend(events)
self._start_background_request(service)
async def _send_request(self, service: ApplicationService) -> None:
# sanity-check: we shouldn't get here if this service already has a sender
# running.
@ -162,11 +191,21 @@ class _ServiceQueuer:
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
if not events and not ephemeral:
all_to_device_messages = self.queued_to_device_messages.get(
service.id, []
)
to_device_messages_to_send = all_to_device_messages[
:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION
]
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
if not events and not ephemeral and not to_device_messages_to_send:
return
try:
await self.txn_ctrl.send(service, events, ephemeral)
await self.txn_ctrl.send(
service, events, ephemeral, to_device_messages_to_send
)
except Exception:
logger.exception("AS request failed")
finally:
@ -198,10 +237,24 @@ class _TransactionController:
service: ApplicationService,
events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
application service.
Args:
service: The application service to send the transaction to.
events: The persistent events to include in the transaction.
ephemeral: The ephemeral events to include in the transaction.
to_device_messages: The to-device messages to include in the transaction.
"""
try:
txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral or []
service=service,
events=events,
ephemeral=ephemeral or [],
to_device_messages=to_device_messages or [],
)
service_is_up = await self._is_service_up(service)
if service_is_up:

View file

@ -52,3 +52,10 @@ class ExperimentalConfig(Config):
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)
# MSC2409 (this setting only relates to optionally sending to-device messages).
# Presence, typing and read receipt EDUs are already sent to application services that
# have opted in to receive them. If enabled, this adds to-device messages to that list.
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)

View file

@ -55,6 +55,9 @@ class ApplicationServicesHandler:
self.clock = hs.get_clock()
self.notify_appservices = hs.config.appservice.notify_appservices
self.event_sources = hs.get_event_sources()
self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
self.current_max = 0
self.is_processing = False
@ -132,7 +135,9 @@ class ApplicationServicesHandler:
# Fork off pushes to these services
for service in services:
self.scheduler.submit_event_for_as(service, event)
self.scheduler.enqueue_for_appservice(
service, events=[event]
)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
@ -199,8 +204,9 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
value for `stream_key` will cause this function to return early.
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
"to_device_key". Any other value for `stream_key` will cause this function
to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@ -216,8 +222,15 @@ class ApplicationServicesHandler:
if not self.notify_appservices:
return
# Ignore any unsupported streams
if stream_key not in ("typing_key", "receipt_key", "presence_key"):
# Notify appservices of updates in ephemeral event streams.
# Only the following streams are currently supported.
# FIXME: We should use constants for these values.
if stream_key not in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
):
return
# Assert that new_token is an integer (and not a RoomStreamToken).
@ -233,6 +246,13 @@ class ApplicationServicesHandler:
# Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int)
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == "to_device_key"
and not self._msc2409_to_device_messages_enabled
):
return
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
@ -266,7 +286,7 @@ class ApplicationServicesHandler:
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
if stream_key == "typing_key":
# Note that we don't persist the token (via set_type_stream_id_for_appservice)
# Note that we don't persist the token (via set_appservice_stream_type_pos)
# for typing_key due to performance reasons and due to their highly
# ephemeral nature.
#
@ -274,7 +294,7 @@ class ApplicationServicesHandler:
# and, if they apply to this application service, send it off.
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
continue
# Since we read/update the stream position for this AS/stream
@ -285,28 +305,37 @@ class ApplicationServicesHandler:
):
if stream_key == "receipt_key":
events = await self._handle_receipts(service, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(
service, events
)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice(
await self.store.set_appservice_stream_type_pos(
service, "read_receipt", new_token
)
elif stream_key == "presence_key":
events = await self._handle_presence(service, users, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(
service, events
)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
await self.store.set_type_stream_id_for_appservice(
await self.store.set_appservice_stream_type_pos(
service, "presence", new_token
)
elif stream_key == "to_device_key":
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
service, new_token, users
)
self.scheduler.enqueue_for_appservice(
service, to_device_messages=to_device_messages
)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "to_device", new_token
)
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@ -440,6 +469,79 @@ class ApplicationServicesHandler:
return events
async def _get_to_device_messages(
self,
service: ApplicationService,
new_token: int,
users: Collection[Union[str, UserID]],
) -> List[JsonDict]:
"""
Given an application service, determine which events it should receive
from those between the last-recorded to-device message stream token for this
appservice and the given stream token.
Args:
service: The application service to check for which events it should receive.
new_token: The latest to-device event stream token.
users: The users to be notified for the new to-device messages
(ie, the recipients of the messages).
Returns:
A list of JSON dictionaries containing data derived from the to-device events
that should be sent to the given application service.
"""
# Get the stream token that this application service has processed up until
from_key = await self.store.get_type_stream_id_for_appservice(
service, "to_device"
)
# Filter out users that this appservice is not interested in
users_appservice_is_interested_in: List[str] = []
for user in users:
# FIXME: We should do this farther up the call stack. We currently repeat
# this operation in _handle_presence.
if isinstance(user, UserID):
user = user.to_string()
if service.is_interested_in_user(user):
users_appservice_is_interested_in.append(user)
if not users_appservice_is_interested_in:
# Return early if the AS was not interested in any of these users
return []
# Retrieve the to-device messages for each user
recipient_device_to_messages = await self.store.get_messages_for_user_devices(
users_appservice_is_interested_in,
from_key,
new_token,
)
# According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
# to the event JSON so that the application service will know which user/device
# combination this messages was intended for.
#
# So we mangle this dict into a flat list of to-device messages with the relevant
# user ID and device ID embedded inside each message dict.
message_payload: List[JsonDict] = []
for (
user_id,
device_id,
), messages in recipient_device_to_messages.items():
for message_json in messages:
# Remove 'message_id' from the to-device message, as it's an internal ID
message_json.pop("message_id", None)
message_payload.append(
{
"to_user_id": user_id,
"to_device_id": device_id,
**message_json,
}
)
return message_payload
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.

View file

@ -1348,8 +1348,8 @@ class SyncHandler:
if sync_result_builder.since_token is not None:
since_stream_id = int(sync_result_builder.since_token.to_device_key)
if since_stream_id != int(now_token.to_device_key):
messages, stream_id = await self.store.get_new_messages_for_device(
if device_id is not None and since_stream_id != int(now_token.to_device_key):
messages, stream_id = await self.store.get_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)

View file

@ -461,7 +461,9 @@ class Notifier:
users,
)
except Exception:
logger.exception("Error notifying application services of event")
logger.exception(
"Error notifying application services of ephemeral events"
)
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened

View file

@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
service: The service who the transaction is for.
events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
Returns:
A new transaction.
@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
(service.id, new_txn_id, event_ids),
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events, ephemeral=ephemeral
service=service,
id=new_txn_id,
events=events,
ephemeral=ephemeral,
to_device_messages=to_device_messages,
)
return await self.db_pool.runInteraction(
@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(
service=service, id=entry["txn_id"], events=events, ephemeral=[]
service=service,
id=entry["txn_id"],
events=events,
ephemeral=[],
to_device_messages=[],
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
if type not in ("read_receipt", "presence"):
if type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
)
async def set_type_stream_id_for_appservice(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
if stream_type not in ("read_receipt", "presence"):
if stream_type not in ("read_receipt", "presence", "to_device"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
)
def set_type_stream_id_for_appservice_txn(txn):
def set_appservice_stream_type_pos_txn(txn):
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
)
await self.db_pool.runInteraction(
"set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
"set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
)

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@ -24,6 +24,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@ -136,63 +137,260 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
async def get_new_messages_for_device(
async def get_messages_for_user_devices(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
) -> Dict[Tuple[str, str], List[JsonDict]]:
"""
Retrieve to-device messages for a given set of users.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Args:
user_ids: The users to retrieve to-device messages for.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
Returns:
A dictionary of (user id, device id) -> list of to-device messages.
"""
# We expect the stream ID returned by _get_device_messages to always
# be to_stream_id. So, no need to return it from this function.
(
user_id_device_id_to_messages,
last_processed_stream_id,
) = await self._get_device_messages(
user_ids=user_ids,
from_stream_id=from_stream_id,
to_stream_id=to_stream_id,
)
assert (
last_processed_stream_id == to_stream_id
), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
return user_id_device_id_to_messages
async def get_messages_for_device(
self,
user_id: str,
device_id: Optional[str],
last_stream_id: int,
current_stream_id: int,
device_id: str,
from_stream_id: int,
to_stream_id: int,
limit: int = 100,
) -> Tuple[List[dict], int]:
) -> Tuple[List[JsonDict], int]:
"""
Retrieve to-device messages for a single user device.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Args:
user_id: The recipient user_id.
device_id: The recipient device_id.
last_stream_id: The last stream ID checked.
current_stream_id: The current position of the to device
message stream.
limit: The maximum number of messages to retrieve.
user_id: The ID of the user to retrieve messages for.
device_id: The ID of the device to retrieve to-device messages for.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
limit: A limit on the number of to-device messages returned.
Returns:
A tuple containing:
* A list of messages for the device.
* The max stream token of these messages. There may be more to retrieve
if the given limit was reached.
* A list of to-device messages within the given stream id range intended for
the given user / device combo.
* The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
(
user_id_device_id_to_messages,
last_processed_stream_id,
) = await self._get_device_messages(
user_ids=[user_id],
device_id=device_id,
from_stream_id=from_stream_id,
to_stream_id=to_stream_id,
limit=limit,
)
if not has_changed:
return [], current_stream_id
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
if not user_id_device_id_to_messages:
# There were no messages!
return [], to_stream_id
# Extract the messages, no need to return the user and device ID again
to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
return to_device_messages, last_processed_stream_id
async def _get_device_messages(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
device_id: Optional[str] = None,
limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve pending to-device messages for a collection of user devices.
Only to-device messages with stream ids between the given boundaries
(from < X <= to) are returned.
Note that a stream ID can be shared by multiple copies of the same message with
different recipient devices. Stream IDs are only unique in the context of a single
user ID / device ID pair. Thus, applying a limit (of messages to return) when working
with a sliding window of stream IDs is only possible when querying messages of a
single user device.
Finally, note that device IDs are not unique across users.
Args:
user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
device_id: A device ID to query to-device messages for. If not provided, to-device
messages from all device IDs for the given user IDs will be queried. May not be
provided if `user_ids` contains more than one entry.
limit: The maximum number of to-device messages to return. Can only be used when
passing a single user ID / device ID tuple.
Returns:
A tuple containing:
* A dict of (user_id, device_id) -> list of to-device messages
* The last-processed stream ID. If this is less than `to_stream_id`, then
there may be more messages to retrieve. If `limit` is not set, then this
is always equal to 'to_stream_id'.
"""
if not user_ids:
logger.warning("No users provided upon querying for device IDs")
return {}, to_stream_id
# Prevent a query for one user's device also retrieving another user's device with
# the same device ID (device IDs are not unique across users).
if len(user_ids) > 1 and device_id is not None:
raise AssertionError(
"Programming error: 'device_id' cannot be supplied to "
"_get_device_messages when >1 user_id has been provided"
)
messages = []
stream_pos = current_stream_id
# A limit can only be applied when querying for a single user ID / device ID tuple.
# See the docstring of this function for more details.
if limit is not None and device_id is None:
raise AssertionError(
"Programming error: _get_device_messages was passed 'limit' "
"without a specific user_id/device_id"
)
user_ids_to_query: Set[str] = set()
device_ids_to_query: Set[str] = set()
# Note that a device ID could be an empty str
if device_id is not None:
# If a device ID was passed, use it to filter results.
# Otherwise, device IDs will be derived from the given collection of user IDs.
device_ids_to_query.add(device_id)
# Determine which users have devices with pending messages
for user_id in user_ids:
if self._device_inbox_stream_cache.has_entity_changed(
user_id, from_stream_id
):
# This user has new messages sent to them. Query messages for them
user_ids_to_query.add(user_id)
def get_device_messages_txn(txn: LoggingTransaction):
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.
# If a list of device IDs was not provided, retrieve all devices IDs
# for the given users. We explicitly do not query hidden devices, as
# hidden devices should not receive to-device messages.
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
user_device_dicts = self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"user_id": user_id, "hidden": False},
retcols=("device_id",),
)
device_ids_to_query.update(
{row["device_id"] for row in user_device_dicts}
)
if not device_ids_to_query:
# We've ended up with no devices to query.
return {}, to_stream_id
# We include both user IDs and device IDs in this query, as we have an index
# (device_inbox_user_stream_id) for them.
user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids_to_query
)
(
device_id_many_clause_sql,
device_id_many_clause_args,
) = make_in_list_sql_clause(
self.database_engine, "device_id", device_ids_to_query
)
sql = f"""
SELECT stream_id, user_id, device_id, message_json FROM device_inbox
WHERE {user_id_many_clause_sql}
AND {device_id_many_clause_sql}
AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
sql_args = (
*user_id_many_clause_args,
*device_id_many_clause_args,
from_stream_id,
to_stream_id,
)
# If a limit was provided, limit the data retrieved from the database
if limit is not None:
sql += "LIMIT ?"
sql_args += (limit,)
txn.execute(sql, sql_args)
# Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
last_processed_stream_pos = to_stream_id
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
for row in txn:
stream_pos = row[0]
messages.append(db_to_json(row[1]))
last_processed_stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
# If the limit was not reached we know that there's no more data for this
# user/device pair up to current_stream_id.
if len(messages) < limit:
stream_pos = current_stream_id
# Store the device details
recipient_device_to_messages.setdefault(
(recipient_user_id, recipient_device_id), []
).append(message_dict)
return messages, stream_pos
if limit is not None and txn.rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
# was processed.
#
# The caller is expected to set this as the lower (exclusive) bound
# for the next query of this device.
return recipient_device_to_messages, last_processed_stream_pos
# The limit was not reached, thus we know that recipient_device_to_messages
# contains all to-device messages for the given device and stream id range.
#
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id
return await self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
"get_device_messages", get_device_messages_txn
)
@trace

View file

@ -0,0 +1,21 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to track what to_device stream id that this application
-- service has been caught up to.
-- NULL indicates that this appservice has never received any to_device messages. This
-- can be used, for example, to avoid sending a huge dump of messages at startup.
ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;

View file

@ -11,23 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from unittest.mock import Mock
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
from synapse.appservice.scheduler import (
ApplicationServiceScheduler,
_Recoverer,
_ServiceQueuer,
_TransactionController,
)
from synapse.logging.context import make_deferred_yieldable
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.test_utils import simple_async_mock
from ..utils import MockClock
if TYPE_CHECKING:
from twisted.internet.testing import MemoryReactor
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self):
@ -58,7 +64,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[] # txn made and saved
service=service,
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@ -79,7 +88,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[] # txn made and saved
service=service,
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@ -102,7 +114,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[]
service=service, events=events, ephemeral=[], to_device_messages=[]
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@ -189,38 +201,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.callback.assert_called_once_with(self.recoverer)
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer):
self.scheduler = ApplicationServiceScheduler(hs)
self.txn_ctrl = Mock()
self.txn_ctrl.send = simple_async_mock()
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
# Replace instantiated _TransactionController instances with our Mock
self.scheduler.txn_ctrl = self.txn_ctrl
self.scheduler.queuer.txn_ctrl = self.txn_ctrl
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
self.queuer.enqueue_event(service, event)
self.txn_ctrl.send.assert_called_once_with(service, [event], [])
self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
self.queuer.enqueue_event(service, event)
self.scheduler.enqueue_for_appservice(service, events=[event])
# Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue_event(service, event2)
self.queuer.enqueue_event(service, event3)
self.txn_ctrl.send.assert_called_with(service, [event], [])
# (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@ -238,23 +253,23 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y, z):
def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
self.queuer.enqueue_event(srv1, srv_1_event)
self.queuer.enqueue_event(srv1, srv_1_event2)
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
self.queuer.enqueue_event(srv2, srv_2_event)
self.queuer.enqueue_event(srv2, srv_2_event2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@ -262,7 +277,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
srv_2_defer = defer.Deferred()
send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y, z):
def do_send(*args, **kwargs):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
@ -270,67 +285,65 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
service = Mock(id=4, name="service")
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
for event in event_list:
self.queuer.enqueue_event(service, event)
self.scheduler.enqueue_for_appservice(service, [event], [])
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], [])
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], [])
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event")]
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
service = Mock(id=4)
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
# Send an event and don't resolve it just yet.
self.queuer.enqueue_ephemeral(service, event_list_1)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1)
# Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue_ephemeral(service, event_list_2)
self.queuer.enqueue_ephemeral(service, event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, []
)
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_send_large_txns_ephemeral(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d))
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, [])
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [])
self.assertEquals(2, self.txn_ctrl.send.call_count)

View file

@ -1,4 +1,4 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,18 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from twisted.internet import defer
import synapse.rest.admin
import synapse.storage
from synapse.appservice import ApplicationService
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, room, sendtodevice
from synapse.types import RoomStreamToken
from synapse.util.stringutils import random_string
from tests.test_utils import make_awaitable
from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
from tests.utils import MockClock
from .. import unittest
class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler."""
@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_datastore.return_value = self.mock_store
self.mock_store.get_received_ts.return_value = make_awaitable(0)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
None
)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, events=[event]
)
def test_query_user_exists_unknown_user(self):
@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"""
interested_service = self._mkservice(is_interested=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
579
@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
interested_service, [event]
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[event]
)
self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with(
self.mock_store.set_appservice_stream_type_pos.assert_called_once_with(
interested_service,
"read_receipt",
580,
@ -305,7 +312,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler.notify_interested_services_ephemeral(
"receipt_key", 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
# This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[]
)
def _mkservice(self, is_interested, protocols=None):
service = Mock()
@ -321,3 +331,252 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"""
Tests that the ApplicationServicesHandler sends events to application
services correctly.
"""
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
sendtodevice.register_servlets,
receipts.register_servlets,
]
def prepare(self, reactor, clock, hs):
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
self.send_mock = simple_async_mock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
self.hs.get_datastore().get_app_services = Mock(return_value=self._services)
# A user on the homeserver.
self.local_user_device_id = "local_device"
self.local_user = self.register_user("local_user", "password")
self.local_user_token = self.login(
"local_user", "password", self.local_user_device_id
)
# A user on the homeserver which lies within an appservice's exclusive user namespace.
self.exclusive_as_user_device_id = "exclusive_as_device"
self.exclusive_as_user = self.register_user("exclusive_as_user", "password")
self.exclusive_as_user_token = self.login(
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
def test_application_services_receive_local_to_device(self):
"""
Test that when a user sends a to-device message to another user
that is an application service's user namespace, the
application service will receive it.
"""
interested_appservice = self._register_application_service(
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@exclusive_as_user:.+",
"exclusive": True,
}
],
},
)
# Have local_user send a to-device message to exclusive_as_user
message_content = {"some_key": "some really interesting value"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/3",
content={
"messages": {
self.exclusive_as_user: {
self.exclusive_as_user_device_id: message_content
}
}
},
access_token=self.local_user_token,
)
self.assertEqual(chan.code, 200, chan.result)
# Have exclusive_as_user send a to-device message to local_user
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/4",
content={
"messages": {
self.local_user: {self.local_user_device_id: message_content}
}
},
access_token=self.exclusive_as_user_token,
)
self.assertEqual(chan.code, 200, chan.result)
# Check if our application service - that is interested in exclusive_as_user - received
# the to-device message as part of an AS transaction.
# Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS.
#
# The uninterested application service should not have been notified at all.
self.send_mock.assert_called_once()
service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent
self.assertEqual(service, interested_appservice)
self.assertEqual(to_device_messages[0]["type"], "m.room_key_request")
self.assertEqual(to_device_messages[0]["sender"], self.local_user)
# Additional fields 'to_user_id' and 'to_device_id' specifically for
# to-device messages via the AS API
self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user)
self.assertEqual(
to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id
)
self.assertEqual(to_device_messages[0]["content"], message_content)
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
def test_application_services_receive_bursts_of_to_device(self):
"""
Test that when a user sends >100 to-device messages at once, any
interested AS's will receive them in separate transactions.
Also tests that uninterested application services do not receive messages.
"""
# Register two application services with exclusive interest in a user
interested_appservices = []
for _ in range(2):
appservice = self._register_application_service(
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@exclusive_as_user:.+",
"exclusive": True,
}
],
},
)
interested_appservices.append(appservice)
# ...and an application service which does not have any user interest.
self._register_application_service()
to_device_message_content = {
"some key": "some interesting value",
}
# We need to send a large burst of to-device messages. We also would like to
# include them all in the same application service transaction so that we can
# test large transactions.
#
# To do this, we can send a single to-device message to many user devices at
# once.
#
# We insert number_of_messages - 1 messages into the database directly. We'll then
# send a final to-device message to the real device, which will also kick off
# an AS transaction (as just inserting messages into the DB won't).
number_of_messages = 150
fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
messages = {
self.exclusive_as_user: {
device_id: to_device_message_content for device_id in fake_device_ids
}
}
# Create a fake device per message. We can't send to-device messages to
# a device that doesn't exist.
self.get_success(
self.hs.get_datastore().db_pool.simple_insert_many(
desc="test_application_services_receive_burst_of_to_device",
table="devices",
keys=("user_id", "device_id"),
values=[
(
self.exclusive_as_user,
device_id,
)
for device_id in fake_device_ids
],
)
)
# Seed the device_inbox table with our fake messages
self.get_success(
self.hs.get_datastore().add_messages_to_device_inbox(messages, {})
)
# Now have local_user send a final to-device message to exclusive_as_user. All unsent
# to-device messages should be sent to any application services
# interested in exclusive_as_user.
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/4",
content={
"messages": {
self.exclusive_as_user: {
self.exclusive_as_user_device_id: to_device_message_content
}
}
},
access_token=self.local_user_token,
)
self.assertEqual(chan.code, 200, chan.result)
self.send_mock.assert_called()
# Count the total number of to-device messages that were sent out per-service.
# Ensure that we only sent to-device messages to interested services, and that
# each interested service received the full count of to-device messages.
service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list:
service, _events, _ephemeral, to_device_messages = call[0]
# Check that this was made to an interested service
self.assertIn(service, interested_appservices)
# Add to the count of messages for this application service
service_id_to_message_count.setdefault(service.id, 0)
service_id_to_message_count[service.id] += len(to_device_messages)
# Assert that each interested service received the full count of messages
for count in service_id_to_message_count.values():
self.assertEqual(count, number_of_messages)
def _register_application_service(
self,
namespaces: Optional[Dict[str, Iterable[Dict]]] = None,
) -> ApplicationService:
"""
Register a new application service, with the given namespaces of interest.
Args:
namespaces: A dictionary containing any user, room or alias namespaces that
the application service is interested in.
Returns:
The registered application service.
"""
# Create an application service
appservice = ApplicationService(
token=random_string(10),
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
namespaces=namespaces,
supports_ephemeral=True,
)
# Register the application service
self._services.append(appservice)
return appservice

View file

@ -266,7 +266,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
defer.ensureDeferred(self.store.create_appservice_txn(service, events, []))
defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], [])
)
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@ -280,7 +282,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind
self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success(self.store.create_appservice_txn(service, events, []))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@ -291,7 +295,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
service = Mock(id=self.as_list[0]["id"])
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success(self.store.create_appservice_txn(service, events, []))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@ -313,7 +319,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events))
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success(self.store.create_appservice_txn(service, events, []))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@ -481,10 +489,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
ValueError,
)
def test_set_type_stream_id_for_appservice(self) -> None:
def test_set_appservice_stream_type_pos(self) -> None:
read_receipt_value = 1024
self.get_success(
self.store.set_type_stream_id_for_appservice(
self.store.set_appservice_stream_type_pos(
self.service, "read_receipt", read_receipt_value
)
)
@ -494,7 +502,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
self.assertEqual(result, read_receipt_value)
self.get_success(
self.store.set_type_stream_id_for_appservice(
self.store.set_appservice_stream_type_pos(
self.service, "presence", read_receipt_value
)
)
@ -503,9 +511,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
)
self.assertEqual(result, read_receipt_value)
def test_set_type_stream_id_for_appservice_invalid_type(self) -> None:
def test_set_appservice_stream_type_pos_invalid_type(self) -> None:
self.get_failure(
self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024),
self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024),
ValueError,
)