0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 04:52:26 +01:00

Handle to-device extensions to Sliding Sync (#17416)

Implements MSC3885

---------

Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
This commit is contained in:
Erik Johnston 2024-07-10 11:58:42 +01:00 committed by GitHub
parent 8e229535fa
commit 4ca13ce0dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 392 additions and 12 deletions

View file

@ -0,0 +1 @@
Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.

View file

@ -542,11 +542,15 @@ class SlidingSyncHandler:
rooms[room_id] = room_sync_result rooms[room_id] = room_sync_result
extensions = await self.get_extensions_response(
sync_config=sync_config, to_token=to_token
)
return SlidingSyncResult( return SlidingSyncResult(
next_pos=to_token, next_pos=to_token,
lists=lists, lists=lists,
rooms=rooms, rooms=rooms,
extensions={}, extensions=extensions,
) )
async def get_sync_room_ids_for_user( async def get_sync_room_ids_for_user(
@ -1445,3 +1449,100 @@ class SlidingSyncHandler:
notification_count=0, notification_count=0,
highlight_count=0, highlight_count=0,
) )
async def get_extensions_response(
self,
sync_config: SlidingSyncConfig,
to_token: StreamToken,
) -> SlidingSyncResult.Extensions:
"""Handle extension requests.
Args:
sync_config: Sync configuration
to_token: The point in the stream to sync up to.
"""
if sync_config.extensions is None:
return SlidingSyncResult.Extensions()
to_device_response = None
if sync_config.extensions.to_device:
to_device_response = await self.get_to_device_extensions_response(
sync_config=sync_config,
to_device_request=sync_config.extensions.to_device,
to_token=to_token,
)
return SlidingSyncResult.Extensions(to_device=to_device_response)
async def get_to_device_extensions_response(
self,
sync_config: SlidingSyncConfig,
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
to_token: StreamToken,
) -> SlidingSyncResult.Extensions.ToDeviceExtension:
"""Handle to-device extension (MSC3885)
Args:
sync_config: Sync configuration
to_device_request: The to-device extension from the request
to_token: The point in the stream to sync up to.
"""
user_id = sync_config.user.to_string()
device_id = sync_config.device_id
# Check that this request has a valid device ID (not all requests have
# to belong to a device, and so device_id is None), and that the
# extension is enabled.
if device_id is None or not to_device_request.enabled:
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{to_token.to_device_key}",
events=[],
)
since_stream_id = 0
if to_device_request.since is not None:
# We've already validated this is an int.
since_stream_id = int(to_device_request.since)
if to_token.to_device_key < since_stream_id:
# The since token is ahead of our current token, so we return an
# empty response.
logger.warning(
"Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
since_stream_id,
to_token.to_device_key,
)
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=to_device_request.since,
events=[],
)
# Delete everything before the given since token, as we know the
# device must have received them.
deleted = await self.store.delete_messages_for_device(
user_id=user_id,
device_id=device_id,
up_to_stream_id=since_stream_id,
)
logger.debug(
"Deleted %d to-device messages up to %d for %s",
deleted,
since_stream_id,
user_id,
)
messages, stream_id = await self.store.get_messages_for_device(
user_id=user_id,
device_id=device_id,
from_stream_id=since_stream_id,
to_stream_id=to_token.to_device_key,
limit=min(to_device_request.limit, 100), # Limit to at most 100 events
)
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{stream_id}",
events=messages,
)

View file

@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
response["rooms"] = await self.encode_rooms( response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms requester, sliding_sync_result.rooms
) )
response["extensions"] = {} # TODO: sliding_sync_result.extensions response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
)
return response return response
@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
return serialized_rooms return serialized_rooms
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
) -> JsonDict:
result = {}
if extensions.to_device is not None:
result["to_device"] = {
"next_batch": extensions.to_device.next_batch,
"events": extensions.to_device.events,
}
return result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)

View file

@ -18,7 +18,7 @@
# #
# #
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
import attr import attr
from typing_extensions import TypedDict from typing_extensions import TypedDict
@ -252,10 +252,39 @@ class SlidingSyncResult:
count: int count: int
ops: List[Operation] ops: List[Operation]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Extensions:
"""Responses for extensions
Attributes:
to_device: The to-device extension (MSC3885)
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ToDeviceExtension:
"""The to-device extension (MSC3885)
Attributes:
next_batch: The to-device stream token the client should use
to get more results
events: A list of to-device messages for the client
"""
next_batch: str
events: Sequence[JsonMapping]
def __bool__(self) -> bool:
return bool(self.events)
to_device: Optional[ToDeviceExtension] = None
def __bool__(self) -> bool:
return bool(self.to_device)
next_pos: StreamToken next_pos: StreamToken
lists: Dict[str, SlidingWindowList] lists: Dict[str, SlidingWindowList]
rooms: Dict[str, RoomResult] rooms: Dict[str, RoomResult]
extensions: JsonMapping extensions: Extensions
def __bool__(self) -> bool: def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used
@ -271,5 +300,5 @@ class SlidingSyncResult:
next_pos=next_pos, next_pos=next_pos,
lists={}, lists={},
rooms={}, rooms={},
extensions={}, extensions=SlidingSyncResult.Extensions(),
) )

View file

@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
class RoomSubscription(CommonRoomParameters): class RoomSubscription(CommonRoomParameters):
pass pass
class Extension(RequestBodyModel): class Extensions(RequestBodyModel):
"""The extensions section of the request.
Extensions MUST have an `enabled` flag which defaults to `false`. If a client
sends an unknown extension name, the server MUST ignore it (or else backwards
compatibility between clients and servers is broken when a newer client tries to
communicate with an older server).
"""
class ToDeviceExtension(RequestBodyModel):
"""The to-device extension (MSC3885)
Attributes:
enabled
limit: Maximum number of to-device messages to return
since: The `next_batch` from the previous sync response
"""
enabled: Optional[StrictBool] = False enabled: Optional[StrictBool] = False
lists: Optional[List[StrictStr]] = None limit: StrictInt = 100
rooms: Optional[List[StrictStr]] = None since: Optional[StrictStr] = None
@validator("since")
def since_token_check(
cls, value: Optional[StrictStr]
) -> Optional[StrictStr]:
# `since` comes in as an opaque string token but we know that it's just
# an integer representing the position in the device inbox stream. We
# want to pre-validate it to make sure it works fine in downstream code.
if value is None:
return value
try:
int(value)
except ValueError:
raise ValueError(
"'extensions.to_device.since' is invalid (should look like an int)"
)
return value
to_device: Optional[ToDeviceExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING: if TYPE_CHECKING:
@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
else: else:
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type] lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
extensions: Optional[Dict[StrictStr, Extension]] = None extensions: Optional[Extensions] = None
@validator("lists") @validator("lists")
def lists_length_check( def lists_length_check(

View file

@ -38,7 +38,16 @@ from synapse.api.constants import (
) )
from synapse.events import EventBase from synapse.events import EventBase
from synapse.handlers.sliding_sync import StateValues from synapse.handlers.sliding_sync import StateValues
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync from synapse.rest.client import (
devices,
knock,
login,
read_marker,
receipts,
room,
sendtodevice,
sync,
)
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.util import Clock from synapse.util import Clock
@ -47,7 +56,7 @@ from tests import unittest
from tests.federation.transport.test_knocking import ( from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin, KnockingStrippedStateEventHelperMixin,
) )
from tests.server import TimedOutException from tests.server import FakeChannel, TimedOutException
from tests.test_utils.event_injection import mark_event_as_partial_state from tests.test_utils.event_injection import mark_event_as_partial_state
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
], ],
channel.json_body["lists"]["foo-list"], channel.json_body["lists"]["foo-list"],
) )
class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
"""Tests for the to-device sliding sync extension"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
sendtodevice.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.sync_endpoint = (
"/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
)
def _assert_to_device_response(
self, channel: FakeChannel, expected_messages: List[JsonDict]
) -> str:
"""Assert the sliding sync response was successful and has the expected
to-device messages.
Returns the next_batch token from the to-device section.
"""
self.assertEqual(channel.code, 200, channel.json_body)
extensions = channel.json_body["extensions"]
to_device = extensions["to_device"]
self.assertIsInstance(to_device["next_batch"], str)
self.assertEqual(to_device["events"], expected_messages)
return to_device["next_batch"]
def test_no_data(self) -> None:
"""Test that enabling to-device extension works, even if there is
no-data
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
# We expect no to-device messages
self._assert_to_device_response(channel, [])
def test_data_initial_sync(self) -> None:
"""Test that we get to-device messages when we don't specify a since
token"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass", "d1")
user2_id = self.register_user("u2", "pass")
user2_tok = self.login(user2_id, "pass", "d2")
# Send the to-device message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user1_id: {"d1": test_msg}}},
access_token=user2_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self._assert_to_device_response(
channel,
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
)
def test_data_incremental_sync(self) -> None:
"""Test that we get to-device messages over incremental syncs"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass", "d1")
user2_id = self.register_user("u2", "pass")
user2_tok = self.login(user2_id, "pass", "d2")
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
# No to-device messages yet.
next_batch = self._assert_to_device_response(channel, [])
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user1_id: {"d1": test_msg}}},
access_token=user2_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
"since": next_batch,
}
},
},
access_token=user1_tok,
)
next_batch = self._assert_to_device_response(
channel,
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
)
# The next sliding sync request should not include the to-device
# message.
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
"since": next_batch,
}
},
},
access_token=user1_tok,
)
self._assert_to_device_response(channel, [])
# An initial sliding sync request should not include the to-device
# message, as it should have been deleted
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self._assert_to_device_response(channel, [])