mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 21:13:23 +01:00
Handle to-device extensions to Sliding Sync (#17416)
Implements MSC3885 --------- Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
This commit is contained in:
parent
8e229535fa
commit
4ca13ce0dd
6 changed files with 392 additions and 12 deletions
1
changelog.d/17416.feature
Normal file
1
changelog.d/17416.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
|
|
@ -542,11 +542,15 @@ class SlidingSyncHandler:
|
|||
|
||||
rooms[room_id] = room_sync_result
|
||||
|
||||
extensions = await self.get_extensions_response(
|
||||
sync_config=sync_config, to_token=to_token
|
||||
)
|
||||
|
||||
return SlidingSyncResult(
|
||||
next_pos=to_token,
|
||||
lists=lists,
|
||||
rooms=rooms,
|
||||
extensions={},
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
async def get_sync_room_ids_for_user(
|
||||
|
@ -1445,3 +1449,100 @@ class SlidingSyncHandler:
|
|||
notification_count=0,
|
||||
highlight_count=0,
|
||||
)
|
||||
|
||||
async def get_extensions_response(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
to_token: StreamToken,
|
||||
) -> SlidingSyncResult.Extensions:
|
||||
"""Handle extension requests.
|
||||
|
||||
Args:
|
||||
sync_config: Sync configuration
|
||||
to_token: The point in the stream to sync up to.
|
||||
"""
|
||||
|
||||
if sync_config.extensions is None:
|
||||
return SlidingSyncResult.Extensions()
|
||||
|
||||
to_device_response = None
|
||||
if sync_config.extensions.to_device:
|
||||
to_device_response = await self.get_to_device_extensions_response(
|
||||
sync_config=sync_config,
|
||||
to_device_request=sync_config.extensions.to_device,
|
||||
to_token=to_token,
|
||||
)
|
||||
|
||||
return SlidingSyncResult.Extensions(to_device=to_device_response)
|
||||
|
||||
async def get_to_device_extensions_response(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
|
||||
to_token: StreamToken,
|
||||
) -> SlidingSyncResult.Extensions.ToDeviceExtension:
|
||||
"""Handle to-device extension (MSC3885)
|
||||
|
||||
Args:
|
||||
sync_config: Sync configuration
|
||||
to_device_request: The to-device extension from the request
|
||||
to_token: The point in the stream to sync up to.
|
||||
"""
|
||||
|
||||
user_id = sync_config.user.to_string()
|
||||
device_id = sync_config.device_id
|
||||
|
||||
# Check that this request has a valid device ID (not all requests have
|
||||
# to belong to a device, and so device_id is None), and that the
|
||||
# extension is enabled.
|
||||
if device_id is None or not to_device_request.enabled:
|
||||
return SlidingSyncResult.Extensions.ToDeviceExtension(
|
||||
next_batch=f"{to_token.to_device_key}",
|
||||
events=[],
|
||||
)
|
||||
|
||||
since_stream_id = 0
|
||||
if to_device_request.since is not None:
|
||||
# We've already validated this is an int.
|
||||
since_stream_id = int(to_device_request.since)
|
||||
|
||||
if to_token.to_device_key < since_stream_id:
|
||||
# The since token is ahead of our current token, so we return an
|
||||
# empty response.
|
||||
logger.warning(
|
||||
"Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
|
||||
since_stream_id,
|
||||
to_token.to_device_key,
|
||||
)
|
||||
return SlidingSyncResult.Extensions.ToDeviceExtension(
|
||||
next_batch=to_device_request.since,
|
||||
events=[],
|
||||
)
|
||||
|
||||
# Delete everything before the given since token, as we know the
|
||||
# device must have received them.
|
||||
deleted = await self.store.delete_messages_for_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
up_to_stream_id=since_stream_id,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Deleted %d to-device messages up to %d for %s",
|
||||
deleted,
|
||||
since_stream_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
messages, stream_id = await self.store.get_messages_for_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
from_stream_id=since_stream_id,
|
||||
to_stream_id=to_token.to_device_key,
|
||||
limit=min(to_device_request.limit, 100), # Limit to at most 100 events
|
||||
)
|
||||
|
||||
return SlidingSyncResult.Extensions.ToDeviceExtension(
|
||||
next_batch=f"{stream_id}",
|
||||
events=messages,
|
||||
)
|
||||
|
|
|
@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
|
|||
response["rooms"] = await self.encode_rooms(
|
||||
requester, sliding_sync_result.rooms
|
||||
)
|
||||
response["extensions"] = {} # TODO: sliding_sync_result.extensions
|
||||
response["extensions"] = await self.encode_extensions(
|
||||
requester, sliding_sync_result.extensions
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
|
|||
|
||||
return serialized_rooms
|
||||
|
||||
async def encode_extensions(
|
||||
self, requester: Requester, extensions: SlidingSyncResult.Extensions
|
||||
) -> JsonDict:
|
||||
result = {}
|
||||
|
||||
if extensions.to_device is not None:
|
||||
result["to_device"] = {
|
||||
"next_batch": extensions.to_device.next_batch,
|
||||
"events": extensions.to_device.events,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
SyncRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#
|
||||
#
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
|
||||
|
||||
import attr
|
||||
from typing_extensions import TypedDict
|
||||
|
@ -252,10 +252,39 @@ class SlidingSyncResult:
|
|||
count: int
|
||||
ops: List[Operation]
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class Extensions:
|
||||
"""Responses for extensions
|
||||
|
||||
Attributes:
|
||||
to_device: The to-device extension (MSC3885)
|
||||
"""
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ToDeviceExtension:
|
||||
"""The to-device extension (MSC3885)
|
||||
|
||||
Attributes:
|
||||
next_batch: The to-device stream token the client should use
|
||||
to get more results
|
||||
events: A list of to-device messages for the client
|
||||
"""
|
||||
|
||||
next_batch: str
|
||||
events: Sequence[JsonMapping]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.events)
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.to_device)
|
||||
|
||||
next_pos: StreamToken
|
||||
lists: Dict[str, SlidingWindowList]
|
||||
rooms: Dict[str, RoomResult]
|
||||
extensions: JsonMapping
|
||||
extensions: Extensions
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
|
@ -271,5 +300,5 @@ class SlidingSyncResult:
|
|||
next_pos=next_pos,
|
||||
lists={},
|
||||
rooms={},
|
||||
extensions={},
|
||||
extensions=SlidingSyncResult.Extensions(),
|
||||
)
|
||||
|
|
|
@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
|
|||
class RoomSubscription(CommonRoomParameters):
|
||||
pass
|
||||
|
||||
class Extension(RequestBodyModel):
|
||||
enabled: Optional[StrictBool] = False
|
||||
lists: Optional[List[StrictStr]] = None
|
||||
rooms: Optional[List[StrictStr]] = None
|
||||
class Extensions(RequestBodyModel):
|
||||
"""The extensions section of the request.
|
||||
|
||||
Extensions MUST have an `enabled` flag which defaults to `false`. If a client
|
||||
sends an unknown extension name, the server MUST ignore it (or else backwards
|
||||
compatibility between clients and servers is broken when a newer client tries to
|
||||
communicate with an older server).
|
||||
"""
|
||||
|
||||
class ToDeviceExtension(RequestBodyModel):
|
||||
"""The to-device extension (MSC3885)
|
||||
|
||||
Attributes:
|
||||
enabled
|
||||
limit: Maximum number of to-device messages to return
|
||||
since: The `next_batch` from the previous sync response
|
||||
"""
|
||||
|
||||
enabled: Optional[StrictBool] = False
|
||||
limit: StrictInt = 100
|
||||
since: Optional[StrictStr] = None
|
||||
|
||||
@validator("since")
|
||||
def since_token_check(
|
||||
cls, value: Optional[StrictStr]
|
||||
) -> Optional[StrictStr]:
|
||||
# `since` comes in as an opaque string token but we know that it's just
|
||||
# an integer representing the position in the device inbox stream. We
|
||||
# want to pre-validate it to make sure it works fine in downstream code.
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
try:
|
||||
int(value)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"'extensions.to_device.since' is invalid (should look like an int)"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
|
||||
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
|
||||
if TYPE_CHECKING:
|
||||
|
@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
|
|||
else:
|
||||
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
|
||||
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
|
||||
extensions: Optional[Dict[StrictStr, Extension]] = None
|
||||
extensions: Optional[Extensions] = None
|
||||
|
||||
@validator("lists")
|
||||
def lists_length_check(
|
||||
|
|
|
@ -38,7 +38,16 @@ from synapse.api.constants import (
|
|||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.handlers.sliding_sync import StateValues
|
||||
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
|
||||
from synapse.rest.client import (
|
||||
devices,
|
||||
knock,
|
||||
login,
|
||||
read_marker,
|
||||
receipts,
|
||||
room,
|
||||
sendtodevice,
|
||||
sync,
|
||||
)
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
|
||||
from synapse.util import Clock
|
||||
|
@ -47,7 +56,7 @@ from tests import unittest
|
|||
from tests.federation.transport.test_knocking import (
|
||||
KnockingStrippedStateEventHelperMixin,
|
||||
)
|
||||
from tests.server import TimedOutException
|
||||
from tests.server import FakeChannel, TimedOutException
|
||||
from tests.test_utils.event_injection import mark_event_as_partial_state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
|
|||
],
|
||||
channel.json_body["lists"]["foo-list"],
|
||||
)
|
||||
|
||||
|
||||
class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests for the to-device sliding sync extension"""
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
sendtodevice.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
# Enable sliding sync
|
||||
config["experimental_features"] = {"msc3575_enabled": True}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.sync_endpoint = (
|
||||
"/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
|
||||
)
|
||||
|
||||
def _assert_to_device_response(
|
||||
self, channel: FakeChannel, expected_messages: List[JsonDict]
|
||||
) -> str:
|
||||
"""Assert the sliding sync response was successful and has the expected
|
||||
to-device messages.
|
||||
|
||||
Returns the next_batch token from the to-device section.
|
||||
"""
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
extensions = channel.json_body["extensions"]
|
||||
to_device = extensions["to_device"]
|
||||
self.assertIsInstance(to_device["next_batch"], str)
|
||||
self.assertEqual(to_device["events"], expected_messages)
|
||||
|
||||
return to_device["next_batch"]
|
||||
|
||||
def test_no_data(self) -> None:
|
||||
"""Test that enabling to-device extension works, even if there is
|
||||
no-data
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.sync_endpoint,
|
||||
{
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
"to_device": {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
access_token=user1_tok,
|
||||
)
|
||||
|
||||
# We expect no to-device messages
|
||||
self._assert_to_device_response(channel, [])
|
||||
|
||||
def test_data_initial_sync(self) -> None:
|
||||
"""Test that we get to-device messages when we don't specify a since
|
||||
token"""
|
||||
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass", "d1")
|
||||
user2_id = self.register_user("u2", "pass")
|
||||
user2_tok = self.login(user2_id, "pass", "d2")
|
||||
|
||||
# Send the to-device message
|
||||
test_msg = {"foo": "bar"}
|
||||
chan = self.make_request(
|
||||
"PUT",
|
||||
"/_matrix/client/r0/sendToDevice/m.test/1234",
|
||||
content={"messages": {user1_id: {"d1": test_msg}}},
|
||||
access_token=user2_tok,
|
||||
)
|
||||
self.assertEqual(chan.code, 200, chan.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.sync_endpoint,
|
||||
{
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
"to_device": {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
access_token=user1_tok,
|
||||
)
|
||||
self._assert_to_device_response(
|
||||
channel,
|
||||
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
|
||||
)
|
||||
|
||||
def test_data_incremental_sync(self) -> None:
|
||||
"""Test that we get to-device messages over incremental syncs"""
|
||||
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass", "d1")
|
||||
user2_id = self.register_user("u2", "pass")
|
||||
user2_tok = self.login(user2_id, "pass", "d2")
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.sync_endpoint,
|
||||
{
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
"to_device": {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
access_token=user1_tok,
|
||||
)
|
||||
# No to-device messages yet.
|
||||
next_batch = self._assert_to_device_response(channel, [])
|
||||
|
||||
test_msg = {"foo": "bar"}
|
||||
chan = self.make_request(
|
||||
"PUT",
|
||||
"/_matrix/client/r0/sendToDevice/m.test/1234",
|
||||
content={"messages": {user1_id: {"d1": test_msg}}},
|
||||
access_token=user2_tok,
|
||||
)
|
||||
self.assertEqual(chan.code, 200, chan.result)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.sync_endpoint,
|
||||
{
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
"to_device": {
|
||||
"enabled": True,
|
||||
"since": next_batch,
|
||||
}
|
||||
},
|
||||
},
|
||||
access_token=user1_tok,
|
||||
)
|
||||
next_batch = self._assert_to_device_response(
|
||||
channel,
|
||||
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
|
||||
)
|
||||
|
||||
# The next sliding sync request should not include the to-device
|
||||
# message.
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.sync_endpoint,
|
||||
{
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
"to_device": {
|
||||
"enabled": True,
|
||||
"since": next_batch,
|
||||
}
|
||||
},
|
||||
},
|
||||
access_token=user1_tok,
|
||||
)
|
||||
self._assert_to_device_response(channel, [])
|
||||
|
||||
# An initial sliding sync request should not include the to-device
|
||||
# message, as it should have been deleted
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.sync_endpoint,
|
||||
{
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
"to_device": {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
access_token=user1_tok,
|
||||
)
|
||||
self._assert_to_device_response(channel, [])
|
||||
|
|
Loading…
Reference in a new issue