0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 06:52:03 +01:00

Send device list updates to application services (MSC3202) - part 1 (#11881)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Andrew Morgan 2022-03-30 14:39:27 +01:00 committed by GitHub
parent 2fc15ac718
commit d8d0271977
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 490 additions and 82 deletions

View file

@ -0,0 +1 @@
Send device list changes to application services as specified by [MSC3202](https://github.com/matrix-org/matrix-spec-proposals/pull/3202), using unstable prefixes. The `msc3202_transaction_extensions` experimental homeserver config option must be enabled and `org.matrix.msc3202: true` must be present in the application service registration file for device list changes to be sent. The "left" field is currently always empty.

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,7 +23,13 @@ from netaddr import IPSet
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id from synapse.types import (
DeviceListUpdates,
GroupID,
JsonDict,
UserID,
get_domain_from_id,
)
from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING: if TYPE_CHECKING:
@ -400,6 +407,7 @@ class AppServiceTransaction:
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts, one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys, unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
): ):
self.service = service self.service = service
self.id = id self.id = id
@ -408,6 +416,7 @@ class AppServiceTransaction:
self.to_device_messages = to_device_messages self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys self.unused_fallback_keys = unused_fallback_keys
self.device_list_summary = device_list_summary
async def send(self, as_api: "ApplicationServiceApi") -> bool: async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface. """Sends this transaction using the provided AS API interface.
@ -424,6 +433,7 @@ class AppServiceTransaction:
to_device_messages=self.to_device_messages, to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts, one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys, unused_fallback_keys=self.unused_fallback_keys,
device_list_summary=self.device_list_summary,
txn_id=self.id, txn_id=self.id,
) )

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -27,7 +28,7 @@ from synapse.appservice import (
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING: if TYPE_CHECKING:
@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient):
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts, one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys, unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
txn_id: Optional[int] = None, txn_id: Optional[int] = None,
) -> bool: ) -> bool:
""" """
@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient):
} }
) )
# TODO: Update to stable prefixes once MSC3202 completes FCP merge
if service.msc3202_transaction_extensions: if service.msc3202_transaction_extensions:
if one_time_key_counts: if one_time_key_counts:
body[ body[
@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient):
body[ body[
"org.matrix.msc3202.device_unused_fallback_keys" "org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys ] = unused_fallback_keys
if device_list_summary:
body["org.matrix.msc3202.device_lists"] = {
"changed": list(device_list_summary.changed),
"left": list(device_list_summary.left),
}
try: try:
await self.put_json( await self.put_json(

View file

@ -72,7 +72,7 @@ from synapse.events import EventBase
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import Clock from synapse.util import Clock
if TYPE_CHECKING: if TYPE_CHECKING:
@ -122,6 +122,7 @@ class ApplicationServiceScheduler:
events: Optional[Collection[EventBase]] = None, events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None, ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None, to_device_messages: Optional[Collection[JsonDict]] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
) -> None: ) -> None:
""" """
Enqueue some data to be sent off to an application service. Enqueue some data to be sent off to an application service.
@ -133,10 +134,18 @@ class ApplicationServiceScheduler:
to_device_messages: The to-device messages to send. These differ from normal to_device_messages: The to-device messages to send. These differ from normal
to-device messages sent to clients, as they have 'to_device_id' and to-device messages sent to clients, as they have 'to_device_id' and
'to_user_id' fields. 'to_user_id' fields.
device_list_summary: A summary of users that the application service either needs
to refresh the device lists of, or those that the application service need no
longer track the device lists of.
""" """
# We purposefully allow this method to run with empty events/ephemeral # We purposefully allow this method to run with empty events/ephemeral
# collections, so that callers do not need to check iterable size themselves. # collections, so that callers do not need to check iterable size themselves.
if not events and not ephemeral and not to_device_messages: if (
not events
and not ephemeral
and not to_device_messages
and not device_list_summary
):
return return
if events: if events:
@ -147,6 +156,10 @@ class ApplicationServiceScheduler:
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
to_device_messages to_device_messages
) )
if device_list_summary:
self.queuer.queued_device_list_summaries.setdefault(
appservice.id, []
).append(device_list_summary)
# Kick off a new application service transaction # Kick off a new application service transaction
self.queuer.start_background_request(appservice) self.queuer.start_background_request(appservice)
@ -169,6 +182,8 @@ class _ServiceQueuer:
self.queued_ephemeral: Dict[str, List[JsonDict]] = {} self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [to_device_message_json]} # dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [device_list_summary]}
self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}
# the appservices which currently have a transaction in flight # the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set() self.requests_in_flight: Set[str] = set()
@ -212,7 +227,35 @@ class _ServiceQueuer:
] ]
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
if not events and not ephemeral and not to_device_messages_to_send: # Consolidate any pending device list summaries into a single, up-to-date
# summary.
# Note: this code assumes that in a single DeviceListUpdates, a user will
# never be in both "changed" and "left" sets.
device_list_summary = DeviceListUpdates()
for summary in self.queued_device_list_summaries.get(service.id, []):
# For every user in the incoming "changed" set:
# * Remove them from the existing "left" set if necessary
# (as we need to start tracking them again)
# * Add them to the existing "changed" set if necessary.
device_list_summary.left.difference_update(summary.changed)
device_list_summary.changed.update(summary.changed)
# For every user in the incoming "left" set:
# * Remove them from the existing "changed" set if necessary
# (we no longer need to track them)
# * Add them to the existing "left" set if necessary.
device_list_summary.changed.difference_update(summary.left)
device_list_summary.left.update(summary.left)
self.queued_device_list_summaries.clear()
if (
not events
and not ephemeral
and not to_device_messages_to_send
# DeviceListUpdates is True if either the 'changed' or 'left' sets have
# at least one entry, otherwise False
and not device_list_summary
):
return return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
@ -240,6 +283,7 @@ class _ServiceQueuer:
to_device_messages_to_send, to_device_messages_to_send,
one_time_key_counts, one_time_key_counts,
unused_fallback_keys, unused_fallback_keys,
device_list_summary,
) )
except Exception: except Exception:
logger.exception("AS request failed") logger.exception("AS request failed")
@ -322,6 +366,7 @@ class _TransactionController:
to_device_messages: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None, one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
) -> None: ) -> None:
""" """
Create a transaction with the given data and send to the provided Create a transaction with the given data and send to the provided
@ -336,6 +381,7 @@ class _TransactionController:
appservice devices in the transaction. appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction. appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
""" """
try: try:
txn = await self.store.create_appservice_txn( txn = await self.store.create_appservice_txn(
@ -345,6 +391,7 @@ class _TransactionController:
to_device_messages=to_device_messages or [], to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {}, one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {}, unused_fallback_keys=unused_fallback_keys or {},
device_list_summary=device_list_summary or DeviceListUpdates(),
) )
service_is_up = await self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:

View file

@ -170,6 +170,7 @@ def _load_appservice(
# When enabled, appservice transactions contain the following information: # When enabled, appservice transactions contain the following information:
# - device One-Time Key counts # - device One-Time Key counts
# - device unused fallback key usage states # - device unused fallback key usage states
# - device list changes
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False) msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool): if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError( raise ValueError(

View file

@ -59,8 +59,9 @@ class ExperimentalConfig(Config):
"msc3202_device_masquerading", False "msc3202_device_masquerading", False
) )
# Portion of MSC3202 related to transaction extensions: # The portion of MSC3202 related to transaction extensions:
# sending one-time key counts and fallback key usage to application services. # sending device list changes, one-time key counts and fallback key
# usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get( self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False "msc3202_transaction_extensions", False
) )

View file

@ -33,7 +33,13 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.types import (
DeviceListUpdates,
JsonDict,
RoomAlias,
RoomStreamToken,
UserID,
)
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -58,6 +64,9 @@ class ApplicationServicesHandler:
self._msc2409_to_device_messages_enabled = ( self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled hs.config.experimental.msc2409_to_device_messages_enabled
) )
self._msc3202_transaction_extensions_enabled = (
hs.config.experimental.msc3202_transaction_extensions
)
self.current_max = 0 self.current_max = 0
self.is_processing = False self.is_processing = False
@ -204,9 +213,9 @@ class ApplicationServicesHandler:
Args: Args:
stream_key: The stream the event came from. stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key", "presence_key" or `stream_key` can be "typing_key", "receipt_key", "presence_key",
"to_device_key". Any other value for `stream_key` will cause this function "to_device_key" or "device_list_key". Any other value for `stream_key`
to return early. will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration receiving them by setting `push_ephemeral` to true in their registration
@ -230,6 +239,7 @@ class ApplicationServicesHandler:
"receipt_key", "receipt_key",
"presence_key", "presence_key",
"to_device_key", "to_device_key",
"device_list_key",
): ):
return return
@ -253,15 +263,37 @@ class ApplicationServicesHandler:
): ):
return return
# Ignore device lists if the feature flag is not enabled
if (
stream_key == "device_list_key"
and not self._msc3202_transaction_extensions_enabled
):
return
# Check whether there are any appservices which have registered to receive # Check whether there are any appservices which have registered to receive
# ephemeral events. # ephemeral events.
# #
# Note that whether these events are actually relevant to these appservices # Note that whether these events are actually relevant to these appservices
# is decided later on. # is decided later on.
services = self.store.get_app_services()
services = [ services = [
service service
for service in self.store.get_app_services() for service in services
if service.supports_ephemeral # Different stream keys require different support booleans
if (
stream_key
in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
)
and service.supports_ephemeral
)
or (
stream_key == "device_list_key"
and service.msc3202_transaction_extensions
)
] ]
if not services: if not services:
# Bail out early if none of the target appservices have explicitly registered # Bail out early if none of the target appservices have explicitly registered
@ -336,6 +368,20 @@ class ApplicationServicesHandler:
service, "to_device", new_token service, "to_device", new_token
) )
elif stream_key == "device_list_key":
device_list_summary = await self._get_device_list_summary(
service, new_token
)
if device_list_summary:
self.scheduler.enqueue_for_appservice(
service, device_list_summary=device_list_summary
)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "device_list", new_token
)
async def _handle_typing( async def _handle_typing(
self, service: ApplicationService, new_token: int self, service: ApplicationService, new_token: int
) -> List[JsonDict]: ) -> List[JsonDict]:
@ -542,6 +588,96 @@ class ApplicationServicesHandler:
return message_payload return message_payload
async def _get_device_list_summary(
self,
appservice: ApplicationService,
new_key: int,
) -> DeviceListUpdates:
"""
Retrieve a list of users who have changed their device lists.
Args:
appservice: The application service to retrieve device list changes for.
new_key: The stream key of the device list change that triggered this method call.
Returns:
A set of device list updates, comprised of users that the appservices needs to:
* resync the device list of, and
* stop tracking the device list of.
"""
# Fetch the last successfully processed device list update stream ID
# for this appservice.
from_key = await self.store.get_type_stream_id_for_appservice(
appservice, "device_list"
)
# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
)
# Filter out any users the application service is not interested in
#
# For each user who changed their device list, we want to check whether this
# appservice would be interested in the change.
filtered_users_with_changed_device_lists = {
user_id
for user_id in users_with_changed_device_lists
if await self._is_appservice_interested_in_device_lists_of_user(
appservice, user_id
)
}
# Create a summary of "changed" and "left" users.
# TODO: Calculate "left" users.
device_list_summary = DeviceListUpdates(
changed=filtered_users_with_changed_device_lists
)
return device_list_summary
async def _is_appservice_interested_in_device_lists_of_user(
self,
appservice: ApplicationService,
user_id: str,
) -> bool:
"""
Returns whether a given application service is interested in the device list
updates of a given user.
The application service is interested in the user's device list updates if any
of the following are true:
* The user is the appservice's sender localpart user.
* The user is in the appservice's user namespace.
* At least one member of one room that the user is a part of is in the
appservice's user namespace.
* The appservice is explicitly (via room ID or alias) interested in at
least one room that the user is in.
Args:
appservice: The application service to gauge interest of.
user_id: The ID of the user whose device list interest is in question.
Returns:
True if the application service is interested in the user's device lists, False
otherwise.
"""
# This method checks against both the sender localpart user as well as if the
# user is in the appservice's user namespace.
if appservice.is_interested_in_user(user_id):
return True
# Determine whether any of the rooms the user is in justifies sending this
# device list update to the application service.
room_ids = await self.store.get_rooms_for_user(user_id)
for room_id in room_ids:
# This method covers checking room members for appservice interest as well as
# room ID and alias checks.
if await appservice.is_interested_in_room(room_id, self.store):
return True
return False
async def query_user_exists(self, user_id: str) -> bool: async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists. """Check if any application service knows this user_id exists.

View file

@ -13,17 +13,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
from typing import ( from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
TYPE_CHECKING,
Any,
Collection,
Dict,
FrozenSet,
List,
Optional,
Set,
Tuple,
)
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -41,6 +31,7 @@ from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
DeviceListUpdates,
JsonDict, JsonDict,
MutableStateMap, MutableStateMap,
Requester, Requester,
@ -184,21 +175,6 @@ class GroupsSyncResult:
return bool(self.join or self.invite or self.leave) return bool(self.join or self.invite or self.leave)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
changed: List of user_ids whose devices may have changed
left: List of user_ids whose devices we no longer track
"""
changed: Collection[str]
left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)
class _RoomChanges: class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined """The set of room entries to include in the sync, plus the set of joined
@ -240,7 +216,7 @@ class SyncResult:
knocked: List[KnockedSyncResult] knocked: List[KnockedSyncResult]
archived: List[ArchivedSyncResult] archived: List[ArchivedSyncResult]
to_device: List[JsonDict] to_device: List[JsonDict]
device_lists: DeviceLists device_lists: DeviceListUpdates
device_one_time_keys_count: JsonDict device_one_time_keys_count: JsonDict
device_unused_fallback_key_types: List[str] device_unused_fallback_key_types: List[str]
groups: Optional[GroupsSyncResult] groups: Optional[GroupsSyncResult]
@ -1264,8 +1240,8 @@ class SyncHandler:
newly_joined_or_invited_or_knocked_users: Set[str], newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str], newly_left_rooms: Set[str],
newly_left_users: Set[str], newly_left_users: Set[str],
) -> DeviceLists: ) -> DeviceListUpdates:
"""Generate the DeviceLists section of sync """Generate the DeviceListUpdates section of sync
Args: Args:
sync_result_builder sync_result_builder
@ -1383,9 +1359,11 @@ class SyncHandler:
if any(e.room_id in joined_rooms for e in entries): if any(e.room_id in joined_rooms for e in entries):
newly_left_users.discard(user_id) newly_left_users.discard(user_id)
return DeviceLists(changed=users_that_have_changed, left=newly_left_users) return DeviceListUpdates(
changed=users_that_have_changed, left=newly_left_users
)
else: else:
return DeviceLists(changed=[], left=[]) return DeviceListUpdates()
async def _generate_sync_entry_for_to_device( async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder" self, sync_result_builder: "SyncResultBuilder"

View file

@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.descriptors import _CacheContext, cached
@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts, one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys, unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction: ) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service """Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the with the given list of events. Ephemeral events are NOT persisted to the
@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction. appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction. appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
Returns: Returns:
A new transaction. A new transaction.
@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages, to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts, one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys, unused_fallback_keys=unused_fallback_keys,
device_list_summary=device_list_summary,
) )
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts and unused fallback keys # TODO: to-device messages, one-time key counts, device list summaries and unused
# are not yet populated for catch-up transactions. # fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability. # We likely want to populate those for reliability.
return AppServiceTransaction( return AppServiceTransaction(
service=service, service=service,
@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[], to_device_messages=[],
one_time_key_counts={}, one_time_key_counts={},
unused_fallback_keys={}, unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
) )
def _get_last_txn(self, txn, service_id: Optional[str]) -> int: def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice( async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str self, service: ApplicationService, type: str
) -> int: ) -> int:
if type not in ("read_receipt", "presence", "to_device"): if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError( raise ValueError(
"Expected type to be a valid application stream id type, got %s" "Expected type to be a valid application stream id type, got %s"
% (type,) % (type,)
@ -458,7 +462,7 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos( async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int] self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None: ) -> None:
if stream_type not in ("read_receipt", "presence", "to_device"): if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError( raise ValueError(
"Expected type to be a valid application stream id type, got %s" "Expected type to be a valid application stream id type, got %s"
% (stream_type,) % (stream_type,)

View file

@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key) return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed( async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str] self,
from_key: int,
user_ids: Optional[Iterable[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]: ) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that """Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids. are in the given list of user_ids.
Args: Args:
from_key: The device lists stream token from_key: The minimum device lists stream token to query device list changes for,
user_ids: The user IDs to query for devices. exclusive.
user_ids: If provided, only check if these users have changed their device lists.
Otherwise changes from all users are returned.
to_key: The maximum device lists stream token to query device list changes for,
inclusive.
Returns: Returns:
The set of user_ids whose devices have changed since `from_key` The set of user_ids whose devices have changed since `from_key` (exclusive)
until `to_key` (inclusive).
""" """
# Get set of users who *may* have changed. Users not in the returned # Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed. # list have definitely not changed.
to_check = self._device_list_stream_cache.get_entities_changed( if user_ids is None:
user_ids, from_key # Get set of all users that have had device list changes since 'from_key'
) user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if not to_check: if not user_ids_to_check:
return set() return set()
def _get_users_whose_devices_changed_txn(txn): def _get_users_whose_devices_changed_txn(txn):
changes = set() changes = set()
sql = """ stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args.append(to_key)
sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ? WHERE {stream_id_where_clause}
AND AND
""" """
for chunk in batch_iter(to_check, 100): # Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk txn.database_engine, "user_id", chunk
) )
txn.execute(sql + clause, (from_key,) + tuple(args)) txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn) changes.update(user_id for user_id, in txn)
return changes return changes

View file

@ -0,0 +1,23 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to track what device list changes stream id that this application
-- service has been caught up to.
-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible
-- state is useful for determining if we've ever sent traffic for a stream type
-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for
-- one way this can be used.
ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;

View file

@ -25,6 +25,7 @@ from typing import (
Match, Match,
MutableMapping, MutableMapping,
Optional, Optional,
Set,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -748,6 +749,30 @@ class ReadReceipt:
data: JsonDict data: JsonDict
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListUpdates:
"""
An object containing a diff of information regarding other users' device lists, intended for
a recipient to carry out device list tracking.
Attributes:
changed: A set of users whose device lists have changed recently.
left: A set of users who the recipient no longer needs to track the device lists of.
Typically when those users no longer share any end-to-end encryption enabled rooms.
"""
# We need to use a factory here, otherwise `set` is not evaluated at
# object instantiation, but instead at class definition instantiation.
# The latter happening only once, thus always giving you the same sets
# across multiple DeviceListUpdates instances.
# Also see: don't define mutable default arguments.
changed: Set[str] = attr.ib(factory=set)
left: Set[str] = attr.ib(factory=set)
def __bool__(self) -> bool:
return bool(self.changed or self.left)
def get_verify_key_from_cross_signing_key(key_info): def get_verify_key_from_cross_signing_key(key_info):
"""Get the key ID and signedjson verify key from a cross-signing key dict """Get the key ID and signedjson verify key from a cross-signing key dict

View file

@ -24,6 +24,7 @@ from synapse.appservice.scheduler import (
) )
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import DeviceListUpdates
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved to_device_messages=[], # txn made and saved
one_time_key_counts={}, one_time_key_counts={},
unused_fallback_keys={}, unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
) )
self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed txn.complete.assert_called_once_with(self.store) # txn completed
@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved to_device_messages=[], # txn made and saved
one_time_key_counts={}, one_time_key_counts={},
unused_fallback_keys={}, unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
) )
self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.send.call_count) # txn not sent though
self.assertEqual(0, txn.complete.call_count) # or completed self.assertEqual(0, txn.complete.call_count) # or completed
@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], to_device_messages=[],
one_time_key_counts={}, one_time_key_counts={},
unused_fallback_keys={}, unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
) )
self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made
self.assertEqual(1, self.recoverer.recover.call_count) # and invoked self.assertEqual(1, self.recoverer.recover.call_count) # and invoked
@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4) service = Mock(id=4)
event = Mock() event = Mock()
self.scheduler.enqueue_for_appservice(service, events=[event]) self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) self.txn_ctrl.send.assert_called_once_with(
service, [event], [], [], None, None, DeviceListUpdates()
)
def test_send_single_event_with_queue(self): def test_send_single_event_with_queue(self):
d = defer.Deferred() d = defer.Deferred()
@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# (call enqueue_for_appservice multiple times deliberately) # (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3]) self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) self.txn_ctrl.send.assert_called_with(
service, [event], [], [], None, None, DeviceListUpdates()
)
self.assertEqual(1, self.txn_ctrl.send.call_count) self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent # Resolve the send event: expect the queued events to be sent
d.callback(service) d.callback(service)
self.txn_ctrl.send.assert_called_with( self.txn_ctrl.send.assert_called_with(
service, [event2, event3], [], [], None, None service, [event2, event3], [], [], None, None, DeviceListUpdates()
) )
self.assertEqual(2, self.txn_ctrl.send.call_count) self.assertEqual(2, self.txn_ctrl.send.call_count)
@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# send events for different ASes and make sure they are sent # send events for different ASes and make sure they are sent
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) self.txn_ctrl.send.assert_called_with(
srv1, [srv_1_event], [], [], None, None, DeviceListUpdates()
)
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event], [], [], None, None, DeviceListUpdates()
)
# make sure callbacks for a service only send queued events for THAT # make sure callbacks for a service only send queued events for THAT
# service # service
srv_2_defer.callback(srv2) srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event2], [], [], None, None, DeviceListUpdates()
)
self.assertEqual(3, self.txn_ctrl.send.call_count) self.assertEqual(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self): def test_send_large_txns(self):
@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Expect the first event to be sent immediately. # Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with( self.txn_ctrl.send.assert_called_with(
service, [event_list[0]], [], [], None, None service, [event_list[0]], [], [], None, None, DeviceListUpdates()
) )
srv_1_defer.callback(service) srv_1_defer.callback(service)
# Then send the next 100 events # Then send the next 100 events
self.txn_ctrl.send.assert_called_with( self.txn_ctrl.send.assert_called_with(
service, event_list[1:101], [], [], None, None service, event_list[1:101], [], [], None, None, DeviceListUpdates()
) )
srv_2_defer.callback(service) srv_2_defer.callback(service)
# Then the final 99 events # Then the final 99 events
self.txn_ctrl.send.assert_called_with( self.txn_ctrl.send.assert_called_with(
service, event_list[101:], [], [], None, None service, event_list[101:], [], [], None, None, DeviceListUpdates()
) )
self.assertEqual(3, self.txn_ctrl.send.call_count) self.assertEqual(3, self.txn_ctrl.send.call_count)
@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event")] event_list = [Mock(name="event")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with( self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None service, [], event_list, [], None, None, DeviceListUpdates()
) )
def test_send_multiple_ephemeral_no_queue(self): def test_send_multiple_ephemeral_no_queue(self):
@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with( self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None service, [], event_list, [], None, None, DeviceListUpdates()
) )
def test_send_single_ephemeral_with_queue(self): def test_send_single_ephemeral_with_queue(self):
@ -345,13 +359,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Send more events: expect send() to NOT be called multiple times. # Send more events: expect send() to NOT be called multiple times.
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) self.txn_ctrl.send.assert_called_with(
service, [], event_list_1, [], None, None, DeviceListUpdates()
)
self.assertEqual(1, self.txn_ctrl.send.call_count) self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send # Resolve txn_ctrl.send
d.callback(service) d.callback(service)
# Expect the queued events to be sent # Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with( self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, [], None, None service,
[],
event_list_2 + event_list_3,
[],
None,
None,
DeviceListUpdates(),
) )
self.assertEqual(2, self.txn_ctrl.send.call_count) self.assertEqual(2, self.txn_ctrl.send.call_count)
@ -365,8 +387,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = first_chunk + second_chunk event_list = first_chunk + second_chunk
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with( self.txn_ctrl.send.assert_called_once_with(
service, [], first_chunk, [], None, None service, [], first_chunk, [], None, None, DeviceListUpdates()
) )
d.callback(service) d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) self.txn_ctrl.send.assert_called_with(
service, [], second_chunk, [], None, None, DeviceListUpdates()
)
self.assertEqual(2, self.txn_ctrl.send.call_count) self.assertEqual(2, self.txn_ctrl.send.call_count)

View file

@ -15,6 +15,8 @@
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from parameterized import parameterized
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -471,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
to_device_messages, to_device_messages,
_otks, _otks,
_fbks, _fbks,
_device_list_summary,
) = self.send_mock.call_args[0] ) = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent # Assert that this was the same to-device message that local_user sent
@ -583,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
service_id_to_message_count: Dict[str, int] = {} service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list: for call in self.send_mock.call_args_list:
service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] (
service,
_events,
_ephemeral,
to_device_messages,
_otks,
_fbks,
_device_list_summary,
) = call[0]
# Check that this was made to an interested service # Check that this was made to an interested service
self.assertIn(service, interested_appservices) self.assertIn(service, interested_appservices)
@ -627,6 +638,114 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
return appservice return appservice
class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase):
"""
Tests that the ApplicationServicesHandler sends device list updates to application
services correctly.
"""
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Allow us to modify cached feature flags mid-test
self.as_handler = hs.get_application_service_handler()
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
# will be sent over the wire
self.put_json = simple_async_mock()
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
self.hs.get_datastores().main.get_app_services = Mock(
return_value=self._services
)
# Test across a variety of configuration values
@parameterized.expand(
[
(True, True, True),
(True, False, False),
(False, True, False),
(False, False, False),
]
)
def test_application_service_receives_device_list_updates(
self,
experimental_feature_enabled: bool,
as_supports_txn_extensions: bool,
as_should_receive_device_list_updates: bool,
):
"""
Tests that an application service receives notice of changed device
lists for a user, when a user changes their device lists.
Arguments above are populated by parameterized.
Args:
as_should_receive_device_list_updates: Whether we expect the AS to receive the
device list changes.
experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental
feature is enabled. This feature must be enabled for device lists to ASs to work.
as_supports_txn_extensions: Whether the application service has explicitly registered
to receive information defined by MSC3202 - which includes device list changes.
"""
# Change whether the experimental feature is enabled or disabled before making
# device list changes
self.as_handler._msc3202_transaction_extensions_enabled = (
experimental_feature_enabled
)
# Create an appservice that is interested in "local_user"
appservice = ApplicationService(
token=random_string(10),
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@local_user:.+",
"exclusive": False,
}
],
},
supports_ephemeral=True,
msc3202_transaction_extensions=as_supports_txn_extensions,
# Must be set for Synapse to try pushing data to the AS
hs_token="abcde",
url="some_url",
)
# Register the application service
self._services.append(appservice)
# Register a user on the homeserver
self.local_user = self.register_user("local_user", "password")
self.local_user_token = self.login("local_user", "password")
if as_should_receive_device_list_updates:
# Ensure that the resulting JSON uses the unstable prefix and contains the
# expected users
self.put_json.assert_called_once()
json_body = self.put_json.call_args[1]["json_body"]
# Our application service should have received a device list update with
# "local_user" in the "changed" list
device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {})
self.assertEqual([], device_list_dict["left"])
self.assertEqual([self.local_user], device_list_dict["changed"])
else:
# No device list changes should have been sent out
self.put_json.assert_not_called()
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Argument indices for pulling out arguments from a `send_mock`. # Argument indices for pulling out arguments from a `send_mock`.
ARG_OTK_COUNTS = 4 ARG_OTK_COUNTS = 4

View file

@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
ApplicationServiceStore, ApplicationServiceStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
) )
from synapse.types import DeviceListUpdates
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -267,7 +268,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success( txn = self.get_success(
defer.ensureDeferred( defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], [], {}, {}) self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceListUpdates()
)
) )
) )
self.assertEqual(txn.id, 1) self.assertEqual(txn.id, 1)
@ -283,7 +286,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events)) self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success( txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {}) self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceListUpdates()
)
) )
self.assertEqual(txn.id, 9646) self.assertEqual(txn.id, 9646)
self.assertEqual(txn.events, events) self.assertEqual(txn.events, events)
@ -296,7 +301,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643)) self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success( txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {}) self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceListUpdates()
)
) )
self.assertEqual(txn.id, 9644) self.assertEqual(txn.id, 9644)
self.assertEqual(txn.events, events) self.assertEqual(txn.events, events)
@ -320,7 +327,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success( txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {}) self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceListUpdates()
)
) )
self.assertEqual(txn.id, 9644) self.assertEqual(txn.id, 9644)
self.assertEqual(txn.events, events) self.assertEqual(txn.events, events)