0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-14 14:01:59 +01:00

Sliding Sync: Add E2EE extension (MSC3884) (#17454)

Spec: [MSC3884](https://github.com/matrix-org/matrix-spec-proposals/pull/3884)

Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
This commit is contained in:
Eric Eastwood 2024-07-22 15:40:06 -05:00 committed by GitHub
parent d221512498
commit de05a64246
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 1023 additions and 34 deletions

View file

@ -0,0 +1 @@
Add E2EE extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.

View file

@ -39,6 +39,7 @@ from synapse.metrics.background_process_metrics import (
) )
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
from synapse.types import ( from synapse.types import (
DeviceListUpdates,
JsonDict, JsonDict,
JsonMapping, JsonMapping,
ScheduledTask, ScheduledTask,
@ -214,7 +215,7 @@ class DeviceWorkerHandler:
@cancellable @cancellable
async def get_user_ids_changed( async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken self, user_id: str, from_token: StreamToken
) -> JsonDict: ) -> DeviceListUpdates:
"""Get list of users that have had the devices updated, or have newly """Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in. joined a room, that `user_id` may be interested in.
""" """
@ -341,11 +342,19 @@ class DeviceWorkerHandler:
possibly_joined = set() possibly_joined = set()
possibly_left = set() possibly_left = set()
result = {"changed": list(possibly_joined), "left": list(possibly_left)} device_list_updates = DeviceListUpdates(
changed=possibly_joined,
left=possibly_left,
)
log_kv(result) log_kv(
{
"changed": device_list_updates.changed,
"left": device_list_updates.left,
}
)
return result return device_list_updates
async def on_federation_query_user_devices(self, user_id: str) -> JsonDict: async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
if not self.hs.is_mine(UserID.from_string(user_id)): if not self.hs.is_mine(UserID.from_string(user_id)):

View file

@ -19,7 +19,18 @@
# #
import logging import logging
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Any,
Dict,
Final,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
import attr import attr
from immutabledict import immutabledict from immutabledict import immutabledict
@ -33,6 +44,7 @@ from synapse.storage.databases.main.roommember import extract_heroes_from_room_s
from synapse.storage.databases.main.stream import CurrentStateDeltaMembership from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.types import ( from synapse.types import (
DeviceListUpdates,
JsonDict, JsonDict,
PersistedEventPosition, PersistedEventPosition,
Requester, Requester,
@ -343,6 +355,7 @@ class SlidingSyncHandler:
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.relations_handler = hs.get_relations_handler() self.relations_handler = hs.get_relations_handler()
self.device_handler = hs.get_device_handler()
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
async def wait_for_sync_for_user( async def wait_for_sync_for_user(
@ -371,10 +384,6 @@ class SlidingSyncHandler:
# auth_blocking will occur) # auth_blocking will occur)
await self.auth_blocking.check_auth_blocking(requester=requester) await self.auth_blocking.check_auth_blocking(requester=requester)
# TODO: If the To-Device extension is enabled and we have a `from_token`, delete
# any to-device messages before that token (since we now know that the device
# has received them). (see sync v2 for how to do this)
# If we're working with a user-provided token, we need to make sure to wait for # If we're working with a user-provided token, we need to make sure to wait for
# this worker to catch up with the token so we don't skip past any incoming # this worker to catch up with the token so we don't skip past any incoming
# events or future events if the user is nefariously, manually modifying the # events or future events if the user is nefariously, manually modifying the
@ -617,7 +626,9 @@ class SlidingSyncHandler:
await concurrently_execute(handle_room, relevant_room_map, 10) await concurrently_execute(handle_room, relevant_room_map, 10)
extensions = await self.get_extensions_response( extensions = await self.get_extensions_response(
sync_config=sync_config, to_token=to_token sync_config=sync_config,
from_token=from_token,
to_token=to_token,
) )
return SlidingSyncResult( return SlidingSyncResult(
@ -1776,33 +1787,47 @@ class SlidingSyncHandler:
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
to_token: StreamToken, to_token: StreamToken,
from_token: Optional[StreamToken],
) -> SlidingSyncResult.Extensions: ) -> SlidingSyncResult.Extensions:
"""Handle extension requests. """Handle extension requests.
Args: Args:
sync_config: Sync configuration sync_config: Sync configuration
to_token: The point in the stream to sync up to. to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
""" """
if sync_config.extensions is None: if sync_config.extensions is None:
return SlidingSyncResult.Extensions() return SlidingSyncResult.Extensions()
to_device_response = None to_device_response = None
if sync_config.extensions.to_device: if sync_config.extensions.to_device is not None:
to_device_response = await self.get_to_device_extensions_response( to_device_response = await self.get_to_device_extension_response(
sync_config=sync_config, sync_config=sync_config,
to_device_request=sync_config.extensions.to_device, to_device_request=sync_config.extensions.to_device,
to_token=to_token, to_token=to_token,
) )
return SlidingSyncResult.Extensions(to_device=to_device_response) e2ee_response = None
if sync_config.extensions.e2ee is not None:
e2ee_response = await self.get_e2ee_extension_response(
sync_config=sync_config,
e2ee_request=sync_config.extensions.e2ee,
to_token=to_token,
from_token=from_token,
)
async def get_to_device_extensions_response( return SlidingSyncResult.Extensions(
to_device=to_device_response,
e2ee=e2ee_response,
)
async def get_to_device_extension_response(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension, to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
to_token: StreamToken, to_token: StreamToken,
) -> SlidingSyncResult.Extensions.ToDeviceExtension: ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
"""Handle to-device extension (MSC3885) """Handle to-device extension (MSC3885)
Args: Args:
@ -1810,14 +1835,16 @@ class SlidingSyncHandler:
to_device_request: The to-device extension from the request to_device_request: The to-device extension from the request
to_token: The point in the stream to sync up to. to_token: The point in the stream to sync up to.
""" """
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()
device_id = sync_config.device_id device_id = sync_config.device_id
# Skip if the extension is not enabled
if not to_device_request.enabled:
return None
# Check that this request has a valid device ID (not all requests have # Check that this request has a valid device ID (not all requests have
# to belong to a device, and so device_id is None), and that the # to belong to a device, and so device_id is None)
# extension is enabled. if device_id is None:
if device_id is None or not to_device_request.enabled:
return SlidingSyncResult.Extensions.ToDeviceExtension( return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{to_token.to_device_key}", next_batch=f"{to_token.to_device_key}",
events=[], events=[],
@ -1868,3 +1895,53 @@ class SlidingSyncHandler:
next_batch=f"{stream_id}", next_batch=f"{stream_id}",
events=messages, events=messages,
) )
async def get_e2ee_extension_response(
self,
sync_config: SlidingSyncConfig,
e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
to_token: StreamToken,
from_token: Optional[StreamToken],
) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
"""Handle E2EE device extension (MSC3884)
Args:
sync_config: Sync configuration
e2ee_request: The e2ee extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
user_id = sync_config.user.to_string()
device_id = sync_config.device_id
# Skip if the extension is not enabled
if not e2ee_request.enabled:
return None
device_list_updates: Optional[DeviceListUpdates] = None
if from_token is not None:
# TODO: This should take into account the `from_token` and `to_token`
device_list_updates = await self.device_handler.get_user_ids_changed(
user_id=user_id,
from_token=from_token,
)
device_one_time_keys_count: Mapping[str, int] = {}
device_unused_fallback_key_types: Sequence[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
device_unused_fallback_key_types = (
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
return SlidingSyncResult.Extensions.E2eeExtension(
device_list_updates=device_list_updates,
device_one_time_keys_count=device_one_time_keys_count,
device_unused_fallback_key_types=device_unused_fallback_key_types,
)

View file

@ -256,9 +256,15 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
results = await self.device_handler.get_user_ids_changed(user_id, from_token) device_list_updates = await self.device_handler.get_user_ids_changed(
user_id, from_token
)
return 200, results response: JsonDict = {}
response["changed"] = list(device_list_updates.changed)
response["left"] = list(device_list_updates.left)
return 200, response
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):

View file

@ -1081,15 +1081,41 @@ class SlidingSyncRestServlet(RestServlet):
async def encode_extensions( async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions self, requester: Requester, extensions: SlidingSyncResult.Extensions
) -> JsonDict: ) -> JsonDict:
result = {} serialized_extensions: JsonDict = {}
if extensions.to_device is not None: if extensions.to_device is not None:
result["to_device"] = { serialized_extensions["to_device"] = {
"next_batch": extensions.to_device.next_batch, "next_batch": extensions.to_device.next_batch,
"events": extensions.to_device.events, "events": extensions.to_device.events,
} }
return result if extensions.e2ee is not None:
serialized_extensions["e2ee"] = {
# We always include this because
# https://github.com/vector-im/element-android/issues/3725. The spec
# isn't terribly clear on when this can be omitted and how a client
# would tell the difference between "no keys present" and "nothing
# changed" in terms of whole field absent / individual key type entry
# absent Corresponding synapse issue:
# https://github.com/matrix-org/synapse/issues/10456
"device_one_time_keys_count": extensions.e2ee.device_one_time_keys_count,
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
# states that this field should always be included, as long as the
# server supports the feature.
"device_unused_fallback_key_types": extensions.e2ee.device_unused_fallback_key_types,
}
if extensions.e2ee.device_list_updates is not None:
serialized_extensions["e2ee"]["device_lists"] = {}
serialized_extensions["e2ee"]["device_lists"]["changed"] = list(
extensions.e2ee.device_list_updates.changed
)
serialized_extensions["e2ee"]["device_lists"]["left"] = list(
extensions.e2ee.device_list_updates.left
)
return serialized_extensions
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

View file

@ -1219,11 +1219,12 @@ class ReadReceipt:
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListUpdates: class DeviceListUpdates:
""" """
An object containing a diff of information regarding other users' device lists, intended for An object containing a diff of information regarding other users' device lists,
a recipient to carry out device list tracking. intended for a recipient to carry out device list tracking.
Attributes: Attributes:
changed: A set of users whose device lists have changed recently. changed: A set of users who have updated their device identity or
cross-signing keys, or who now share an encrypted room with.
left: A set of users who the recipient no longer needs to track the device lists of. left: A set of users who the recipient no longer needs to track the device lists of.
Typically when those users no longer share any end-to-end encryption enabled rooms. Typically when those users no longer share any end-to-end encryption enabled rooms.
""" """

View file

@ -18,7 +18,7 @@
# #
# #
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Dict, Final, List, Mapping, Optional, Sequence, Tuple
import attr import attr
from typing_extensions import TypedDict from typing_extensions import TypedDict
@ -31,7 +31,7 @@ else:
from pydantic import Extra from pydantic import Extra
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import JsonDict, JsonMapping, StreamToken, UserID from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, StreamToken, UserID
from synapse.types.rest.client import SlidingSyncBody from synapse.types.rest.client import SlidingSyncBody
if TYPE_CHECKING: if TYPE_CHECKING:
@ -264,6 +264,7 @@ class SlidingSyncResult:
Attributes: Attributes:
to_device: The to-device extension (MSC3885) to_device: The to-device extension (MSC3885)
e2ee: The E2EE device extension (MSC3884)
""" """
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -282,10 +283,51 @@ class SlidingSyncResult:
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self.events) return bool(self.events)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeExtension:
"""The E2EE device extension (MSC3884)
Attributes:
device_list_updates: List of user_ids whose devices have changed or left (only
present on incremental syncs).
device_one_time_keys_count: Map from key algorithm to the number of
unclaimed one-time keys currently held on the server for this device. If
an algorithm is unlisted, the count for that algorithm is assumed to be
zero. If this entire parameter is missing, the count for all algorithms
is assumed to be zero.
device_unused_fallback_key_types: List of unused fallback key algorithms
for this device.
"""
# Only present on incremental syncs
device_list_updates: Optional[DeviceListUpdates]
device_one_time_keys_count: Mapping[str, int]
device_unused_fallback_key_types: Sequence[str]
def __bool__(self) -> bool:
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
default_otk = self.device_one_time_keys_count.get("signed_curve25519")
more_than_default_otk = len(self.device_one_time_keys_count) > 1 or (
default_otk is not None and default_otk > 0
)
return bool(
more_than_default_otk
or self.device_list_updates
or self.device_unused_fallback_key_types
)
to_device: Optional[ToDeviceExtension] = None to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self.to_device) return bool(self.to_device or self.e2ee)
next_pos: StreamToken next_pos: StreamToken
lists: Dict[str, SlidingWindowList] lists: Dict[str, SlidingWindowList]

View file

@ -313,7 +313,17 @@ class SlidingSyncBody(RequestBodyModel):
return value return value
class E2eeExtension(RequestBodyModel):
"""The E2EE device extension (MSC3884)
Attributes:
enabled
"""
enabled: Optional[StrictBool] = False
to_device: Optional[ToDeviceExtension] = None to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING: if TYPE_CHECKING:

View file

@ -59,6 +59,7 @@ from tests.federation.transport.test_knocking import (
) )
from tests.server import FakeChannel, TimedOutException from tests.server import FakeChannel, TimedOutException
from tests.test_utils.event_injection import mark_event_as_partial_state from tests.test_utils.event_injection import mark_event_as_partial_state
from tests.unittest import skip_unless
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -1113,12 +1114,11 @@ class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, []) self.assertEqual(res, [])
# Upload a fallback key for the user/device # Upload a fallback key for the user/device
fallback_key = {"alg1:k1": "fallback_key1"}
self.get_success( self.get_success(
self.e2e_keys_handler.upload_keys_for_user( self.e2e_keys_handler.upload_keys_for_user(
alice_user_id, alice_user_id,
test_device_id, test_device_id,
{"fallback_keys": fallback_key}, {"fallback_keys": {"alg1:k1": "fallback_key1"}},
) )
) )
# We should now have an unused alg1 key # We should now have an unused alg1 key
@ -1252,6 +1252,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.storage_controllers = hs.get_storage_controllers() self.storage_controllers = hs.get_storage_controllers()
self.account_data_handler = hs.get_account_data_handler()
self.notifier = hs.get_notifier()
def _assertRequiredStateIncludes( def _assertRequiredStateIncludes(
self, self,
@ -1377,6 +1379,52 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
return room_id return room_id
def _bump_notifier_wait_for_events(self, user_id: str) -> None:
"""
Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
Sync results.
"""
# We're expecting some new activity from this point onwards
from_token = self.event_sources.get_current_token()
triggered_notifier_wait_for_events = False
async def _on_new_acivity(
before_token: StreamToken, after_token: StreamToken
) -> bool:
nonlocal triggered_notifier_wait_for_events
triggered_notifier_wait_for_events = True
return True
# Listen for some new activity for the user. We're just trying to confirm that
# our bump below actually does what we think it does (triggers new activity for
# the user).
result_awaitable = self.notifier.wait_for_events(
user_id,
1000,
_on_new_acivity,
from_token=from_token,
)
# Update the account data so that `notifier.wait_for_events(...)` wakes up.
# We're bumping account data because it won't show up in the Sliding Sync
# response so it won't affect whether we have results.
self.get_success(
self.account_data_handler.add_account_data_for_user(
user_id,
"org.matrix.foobarbaz",
{"foo": "bar"},
)
)
# Wait for our notifier result
self.get_success(result_awaitable)
if not triggered_notifier_wait_for_events:
raise AssertionError(
"Expected `notifier.wait_for_events(...)` to be triggered"
)
def test_sync_list(self) -> None: def test_sync_list(self) -> None:
""" """
Test that room IDs show up in the Sliding Sync `lists` Test that room IDs show up in the Sliding Sync `lists`
@ -1482,6 +1530,124 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
# with because we weren't able to find anything new yet. # with because we weren't able to find anything new yet.
self.assertEqual(channel.json_body["pos"], future_position_token_serialized) self.assertEqual(channel.json_body["pos"], future_position_token_serialized)
def test_wait_for_new_data(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive.
(Only applies to incremental syncs with a `timeout` specified)
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
from_token = self.event_sources.get_current_token()
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint
+ "?timeout=10000"
+ f"&pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {
"foo-list": {
"ranges": [[0, 0]],
"required_state": [],
"timeline_limit": 1,
}
}
},
access_token=user1_tok,
await_result=False,
)
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=5000)
# Bump the room with new events to trigger new results
event_response1 = self.helper.send(
room_id, "new activity in room", tok=user1_tok
)
# Should respond before the 10 second timeout
channel.await_result(timeout_ms=3000)
self.assertEqual(channel.code, 200, channel.json_body)
# Check to make sure the new event is returned
self.assertEqual(
[
event["event_id"]
for event in channel.json_body["rooms"][room_id]["timeline"]
],
[
event_response1["event_id"],
],
channel.json_body["rooms"][room_id]["timeline"],
)
# TODO: Once we remove `ops`, we should be able to add a `RoomResult.__bool__` to
# check if there are any updates since the `from_token`.
@skip_unless(
False,
"Once we remove ops from the Sliding Sync response, this test should pass",
)
def test_wait_for_new_data_timeout(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive but
no data ever arrives so we timeout. We're also making sure that the default data
doesn't trigger a false-positive for new data.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
from_token = self.event_sources.get_current_token()
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint
+ "?timeout=10000"
+ f"&pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {
"foo-list": {
"ranges": [[0, 0]],
"required_state": [],
"timeline_limit": 1,
}
}
},
access_token=user1_tok,
await_result=False,
)
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=5000)
# Wake-up `notifier.wait_for_events(...)` that will cause us test
# `SlidingSyncResult.__bool__` for new results.
self._bump_notifier_wait_for_events(user1_id)
# Block for a little bit more to ensure we don't see any new results.
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=4000)
# Wait for the sync to complete (wait for the rest of the 10 second timeout,
# 5000 + 4000 + 1200 > 10000)
channel.await_result(timeout_ms=1200)
self.assertEqual(channel.code, 200, channel.json_body)
# We still see rooms because that's how Sliding Sync lists work but we reached
# the timeout before seeing them
self.assertEqual(
[event["event_id"] for event in channel.json_body["rooms"].keys()],
[room_id],
)
def test_filter_list(self) -> None: def test_filter_list(self) -> None:
""" """
Test that filters apply to `lists` Test that filters apply to `lists`
@ -1508,11 +1674,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
) )
# Create a normal room # Create a normal room
room_id = self.helper.create_room_as(user1_id, tok=user2_tok) room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok) self.helper.join(room_id, user1_id, tok=user1_tok)
# Create a room that user1 is invited to # Create a room that user1 is invited to
invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok) invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok) self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
# Make the Sliding Sync request # Make the Sliding Sync request
@ -4320,10 +4486,59 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.event_sources = hs.get_event_sources()
self.account_data_handler = hs.get_account_data_handler()
self.notifier = hs.get_notifier()
self.sync_endpoint = ( self.sync_endpoint = (
"/_matrix/client/unstable/org.matrix.simplified_msc3575/sync" "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
) )
def _bump_notifier_wait_for_events(self, user_id: str) -> None:
"""
Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
Sync results.
"""
# We're expecting some new activity from this point onwards
from_token = self.event_sources.get_current_token()
triggered_notifier_wait_for_events = False
async def _on_new_acivity(
before_token: StreamToken, after_token: StreamToken
) -> bool:
nonlocal triggered_notifier_wait_for_events
triggered_notifier_wait_for_events = True
return True
# Listen for some new activity for the user. We're just trying to confirm that
# our bump below actually does what we think it does (triggers new activity for
# the user).
result_awaitable = self.notifier.wait_for_events(
user_id,
1000,
_on_new_acivity,
from_token=from_token,
)
# Update the account data so that `notifier.wait_for_events(...)` wakes up.
# We're bumping account data because it won't show up in the Sliding Sync
# response so it won't affect whether we have results.
self.get_success(
self.account_data_handler.add_account_data_for_user(
user_id,
"org.matrix.foobarbaz",
{"foo": "bar"},
)
)
# Wait for our notifier result
self.get_success(result_awaitable)
if not triggered_notifier_wait_for_events:
raise AssertionError(
"Expected `notifier.wait_for_events(...)` to be triggered"
)
def _assert_to_device_response( def _assert_to_device_response(
self, channel: FakeChannel, expected_messages: List[JsonDict] self, channel: FakeChannel, expected_messages: List[JsonDict]
) -> str: ) -> str:
@ -4487,3 +4702,605 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
access_token=user1_tok, access_token=user1_tok,
) )
self._assert_to_device_response(channel, []) self._assert_to_device_response(channel, [])
def test_wait_for_new_data(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive.
(Only applies to incremental syncs with a `timeout` specified)
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass", "d1")
user2_id = self.register_user("u2", "pass")
user2_tok = self.login(user2_id, "pass", "d2")
from_token = self.event_sources.get_current_token()
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint
+ "?timeout=10000"
+ f"&pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
await_result=False,
)
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=5000)
# Bump the to-device messages to trigger new results
test_msg = {"foo": "bar"}
send_to_device_channel = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user1_id: {"d1": test_msg}}},
access_token=user2_tok,
)
self.assertEqual(
send_to_device_channel.code, 200, send_to_device_channel.result
)
# Should respond before the 10 second timeout
channel.await_result(timeout_ms=3000)
self.assertEqual(channel.code, 200, channel.json_body)
self._assert_to_device_response(
channel,
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
)
def test_wait_for_new_data_timeout(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive but
no data ever arrives so we timeout. We're also making sure that the default data
from the To-Device extension doesn't trigger a false-positive for new data.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
from_token = self.event_sources.get_current_token()
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint
+ "?timeout=10000"
+ f"&pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
await_result=False,
)
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=5000)
# Wake-up `notifier.wait_for_events(...)` that will cause us test
# `SlidingSyncResult.__bool__` for new results.
self._bump_notifier_wait_for_events(user1_id)
# Block for a little bit more to ensure we don't see any new results.
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=4000)
# Wait for the sync to complete (wait for the rest of the 10 second timeout,
# 5000 + 4000 + 1200 > 10000)
channel.await_result(timeout_ms=1200)
self.assertEqual(channel.code, 200, channel.json_body)
self._assert_to_device_response(channel, [])
class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
"""Tests for the e2ee sliding sync extension"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.event_sources = hs.get_event_sources()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.account_data_handler = hs.get_account_data_handler()
self.notifier = hs.get_notifier()
self.sync_endpoint = (
"/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
)
def _bump_notifier_wait_for_events(self, user_id: str) -> None:
"""
Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
Sync results.
"""
# We're expecting some new activity from this point onwards
from_token = self.event_sources.get_current_token()
triggered_notifier_wait_for_events = False
async def _on_new_acivity(
before_token: StreamToken, after_token: StreamToken
) -> bool:
nonlocal triggered_notifier_wait_for_events
triggered_notifier_wait_for_events = True
return True
# Listen for some new activity for the user. We're just trying to confirm that
# our bump below actually does what we think it does (triggers new activity for
# the user).
result_awaitable = self.notifier.wait_for_events(
user_id,
1000,
_on_new_acivity,
from_token=from_token,
)
# Update the account data so that `notifier.wait_for_events(...)` wakes up.
# We're bumping account data because it won't show up in the Sliding Sync
# response so it won't affect whether we have results.
self.get_success(
self.account_data_handler.add_account_data_for_user(
user_id,
"org.matrix.foobarbaz",
{"foo": "bar"},
)
)
# Wait for our notifier result
self.get_success(result_awaitable)
if not triggered_notifier_wait_for_events:
raise AssertionError(
"Expected `notifier.wait_for_events(...)` to be triggered"
)
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling e2ee extension works during an intitial sync, even if there
is no-data
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Make an initial Sliding Sync request with the e2ee extension enabled
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"e2ee": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Device list updates are only present for incremental syncs
self.assertIsNone(channel.json_body["extensions"]["e2ee"].get("device_lists"))
# Both of these should be present even when empty
self.assertEqual(
channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
{
# This is always present because of
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
"signed_curve25519": 0
},
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
[],
)
def test_no_data_incremental_sync(self) -> None:
"""
Test that enabling e2ee extension works during an incremental sync, even if
there is no-data
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
from_token = self.event_sources.get_current_token()
# Make an incremental Sliding Sync request with the e2ee extension enabled
channel = self.make_request(
"POST",
self.sync_endpoint
+ f"?pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {},
"extensions": {
"e2ee": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Device list shows up for incremental syncs
self.assertEqual(
channel.json_body["extensions"]["e2ee"]
.get("device_lists", {})
.get("changed"),
[],
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
[],
)
# Both of these should be present even when empty
self.assertEqual(
channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
{
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
"signed_curve25519": 0
},
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
[],
)
def test_wait_for_new_data(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive.
(Only applies to incremental syncs with a `timeout` specified)
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
test_device_id = "TESTDEVICE"
user3_id = self.register_user("user3", "pass")
user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
self.helper.join(room_id, user3_id, tok=user3_tok)
from_token = self.event_sources.get_current_token()
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint
+ "?timeout=10000"
+ f"&pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {},
"extensions": {
"e2ee": {
"enabled": True,
}
},
},
access_token=user1_tok,
await_result=False,
)
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=5000)
# Bump the device lists to trigger new results
# Have user3 update their device list
device_update_channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=user3_tok,
)
self.assertEqual(
device_update_channel.code, 200, device_update_channel.json_body
)
# Should respond before the 10 second timeout
channel.await_result(timeout_ms=3000)
self.assertEqual(channel.code, 200, channel.json_body)
# We should see the device list update
self.assertEqual(
channel.json_body["extensions"]["e2ee"]
.get("device_lists", {})
.get("changed"),
[user3_id],
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
[],
)
def test_wait_for_new_data_timeout(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive but
no data ever arrives so we timeout. We're also making sure that the default data
from the E2EE extension doesn't trigger a false-positive for new data (see
`device_one_time_keys_count.signed_curve25519`).
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
from_token = self.event_sources.get_current_token()
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint
+ "?timeout=10000"
+ f"&pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {},
"extensions": {
"e2ee": {
"enabled": True,
}
},
},
access_token=user1_tok,
await_result=False,
)
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=5000)
# Wake-up `notifier.wait_for_events(...)` that will cause us test
# `SlidingSyncResult.__bool__` for new results.
self._bump_notifier_wait_for_events(user1_id)
# Block for a little bit more to ensure we don't see any new results.
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=4000)
# Wait for the sync to complete (wait for the rest of the 10 second timeout,
# 5000 + 4000 + 1200 > 10000)
channel.await_result(timeout_ms=1200)
self.assertEqual(channel.code, 200, channel.json_body)
# Device lists are present for incremental syncs but empty because no device changes
self.assertEqual(
channel.json_body["extensions"]["e2ee"]
.get("device_lists", {})
.get("changed"),
[],
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
[],
)
# Both of these should be present even when empty
self.assertEqual(
channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
{
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
"signed_curve25519": 0
},
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
[],
)
def test_device_lists(self) -> None:
"""
Test that device list updates are included in the response
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
test_device_id = "TESTDEVICE"
user3_id = self.register_user("user3", "pass")
user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
user4_id = self.register_user("user4", "pass")
user4_tok = self.login(user4_id, "pass")
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
self.helper.join(room_id, user3_id, tok=user3_tok)
self.helper.join(room_id, user4_id, tok=user4_tok)
from_token = self.event_sources.get_current_token()
# Have user3 update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=user3_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# User4 leaves the room
self.helper.leave(room_id, user4_id, tok=user4_tok)
# Make an incremental Sliding Sync request with the e2ee extension enabled
channel = self.make_request(
"POST",
self.sync_endpoint
+ f"?pos={self.get_success(from_token.to_string(self.store))}",
{
"lists": {},
"extensions": {
"e2ee": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Device list updates show up
self.assertEqual(
channel.json_body["extensions"]["e2ee"]
.get("device_lists", {})
.get("changed"),
[user3_id],
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
[user4_id],
)
def test_device_one_time_keys_count(self) -> None:
"""
Test that `device_one_time_keys_count` are included in the response
"""
test_device_id = "TESTDEVICE"
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
upload_keys_response = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
user1_id, test_device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(
upload_keys_response,
{
"one_time_key_counts": {
"alg1": 1,
"alg2": 2,
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
"signed_curve25519": 0,
}
},
)
# Make a Sliding Sync request with the e2ee extension enabled
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"e2ee": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertEqual(
channel.json_body["extensions"]["e2ee"].get("device_one_time_keys_count"),
{
"alg1": 1,
"alg2": 2,
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
"signed_curve25519": 0,
},
)
def test_device_unused_fallback_key_types(self) -> None:
"""
Test that `device_unused_fallback_key_types` are included in the response
"""
test_device_id = "TESTDEVICE"
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
# We shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
)
self.assertEqual(res, [])
# Upload a fallback key for the user/device
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
user1_id,
test_device_id,
{"fallback_keys": {"alg1:k1": "fallback_key1"}},
)
)
# We should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
)
self.assertEqual(fallback_res, ["alg1"], fallback_res)
# Make a Sliding Sync request with the e2ee extension enabled
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"e2ee": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for the unused fallback key types
self.assertListEqual(
channel.json_body["extensions"]["e2ee"].get(
"device_unused_fallback_key_types"
),
["alg1"],
)