mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-12 04:52:26 +01:00
Handle to-device extensions to Sliding Sync (#17416)
Implements MSC3885 --------- Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
This commit is contained in:
parent
8e229535fa
commit
4ca13ce0dd
6 changed files with 392 additions and 12 deletions
1
changelog.d/17416.feature
Normal file
1
changelog.d/17416.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
|
|
@ -542,11 +542,15 @@ class SlidingSyncHandler:
|
||||||
|
|
||||||
rooms[room_id] = room_sync_result
|
rooms[room_id] = room_sync_result
|
||||||
|
|
||||||
|
extensions = await self.get_extensions_response(
|
||||||
|
sync_config=sync_config, to_token=to_token
|
||||||
|
)
|
||||||
|
|
||||||
return SlidingSyncResult(
|
return SlidingSyncResult(
|
||||||
next_pos=to_token,
|
next_pos=to_token,
|
||||||
lists=lists,
|
lists=lists,
|
||||||
rooms=rooms,
|
rooms=rooms,
|
||||||
extensions={},
|
extensions=extensions,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_sync_room_ids_for_user(
|
async def get_sync_room_ids_for_user(
|
||||||
|
@ -1445,3 +1449,100 @@ class SlidingSyncHandler:
|
||||||
notification_count=0,
|
notification_count=0,
|
||||||
highlight_count=0,
|
highlight_count=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_extensions_response(
|
||||||
|
self,
|
||||||
|
sync_config: SlidingSyncConfig,
|
||||||
|
to_token: StreamToken,
|
||||||
|
) -> SlidingSyncResult.Extensions:
|
||||||
|
"""Handle extension requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync_config: Sync configuration
|
||||||
|
to_token: The point in the stream to sync up to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if sync_config.extensions is None:
|
||||||
|
return SlidingSyncResult.Extensions()
|
||||||
|
|
||||||
|
to_device_response = None
|
||||||
|
if sync_config.extensions.to_device:
|
||||||
|
to_device_response = await self.get_to_device_extensions_response(
|
||||||
|
sync_config=sync_config,
|
||||||
|
to_device_request=sync_config.extensions.to_device,
|
||||||
|
to_token=to_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return SlidingSyncResult.Extensions(to_device=to_device_response)
|
||||||
|
|
||||||
|
async def get_to_device_extensions_response(
|
||||||
|
self,
|
||||||
|
sync_config: SlidingSyncConfig,
|
||||||
|
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
|
||||||
|
to_token: StreamToken,
|
||||||
|
) -> SlidingSyncResult.Extensions.ToDeviceExtension:
|
||||||
|
"""Handle to-device extension (MSC3885)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync_config: Sync configuration
|
||||||
|
to_device_request: The to-device extension from the request
|
||||||
|
to_token: The point in the stream to sync up to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = sync_config.user.to_string()
|
||||||
|
device_id = sync_config.device_id
|
||||||
|
|
||||||
|
# Check that this request has a valid device ID (not all requests have
|
||||||
|
# to belong to a device, and so device_id is None), and that the
|
||||||
|
# extension is enabled.
|
||||||
|
if device_id is None or not to_device_request.enabled:
|
||||||
|
return SlidingSyncResult.Extensions.ToDeviceExtension(
|
||||||
|
next_batch=f"{to_token.to_device_key}",
|
||||||
|
events=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
since_stream_id = 0
|
||||||
|
if to_device_request.since is not None:
|
||||||
|
# We've already validated this is an int.
|
||||||
|
since_stream_id = int(to_device_request.since)
|
||||||
|
|
||||||
|
if to_token.to_device_key < since_stream_id:
|
||||||
|
# The since token is ahead of our current token, so we return an
|
||||||
|
# empty response.
|
||||||
|
logger.warning(
|
||||||
|
"Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
|
||||||
|
since_stream_id,
|
||||||
|
to_token.to_device_key,
|
||||||
|
)
|
||||||
|
return SlidingSyncResult.Extensions.ToDeviceExtension(
|
||||||
|
next_batch=to_device_request.since,
|
||||||
|
events=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete everything before the given since token, as we know the
|
||||||
|
# device must have received them.
|
||||||
|
deleted = await self.store.delete_messages_for_device(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
up_to_stream_id=since_stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Deleted %d to-device messages up to %d for %s",
|
||||||
|
deleted,
|
||||||
|
since_stream_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages, stream_id = await self.store.get_messages_for_device(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
from_stream_id=since_stream_id,
|
||||||
|
to_stream_id=to_token.to_device_key,
|
||||||
|
limit=min(to_device_request.limit, 100), # Limit to at most 100 events
|
||||||
|
)
|
||||||
|
|
||||||
|
return SlidingSyncResult.Extensions.ToDeviceExtension(
|
||||||
|
next_batch=f"{stream_id}",
|
||||||
|
events=messages,
|
||||||
|
)
|
||||||
|
|
|
@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
|
||||||
response["rooms"] = await self.encode_rooms(
|
response["rooms"] = await self.encode_rooms(
|
||||||
requester, sliding_sync_result.rooms
|
requester, sliding_sync_result.rooms
|
||||||
)
|
)
|
||||||
response["extensions"] = {} # TODO: sliding_sync_result.extensions
|
response["extensions"] = await self.encode_extensions(
|
||||||
|
requester, sliding_sync_result.extensions
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
|
||||||
|
|
||||||
return serialized_rooms
|
return serialized_rooms
|
||||||
|
|
||||||
|
async def encode_extensions(
|
||||||
|
self, requester: Requester, extensions: SlidingSyncResult.Extensions
|
||||||
|
) -> JsonDict:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if extensions.to_device is not None:
|
||||||
|
result["to_device"] = {
|
||||||
|
"next_batch": extensions.to_device.next_batch,
|
||||||
|
"events": extensions.to_device.events,
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
SyncRestServlet(hs).register(http_server)
|
SyncRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
@ -252,10 +252,39 @@ class SlidingSyncResult:
|
||||||
count: int
|
count: int
|
||||||
ops: List[Operation]
|
ops: List[Operation]
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class Extensions:
|
||||||
|
"""Responses for extensions
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
to_device: The to-device extension (MSC3885)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class ToDeviceExtension:
|
||||||
|
"""The to-device extension (MSC3885)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
next_batch: The to-device stream token the client should use
|
||||||
|
to get more results
|
||||||
|
events: A list of to-device messages for the client
|
||||||
|
"""
|
||||||
|
|
||||||
|
next_batch: str
|
||||||
|
events: Sequence[JsonMapping]
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.events)
|
||||||
|
|
||||||
|
to_device: Optional[ToDeviceExtension] = None
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.to_device)
|
||||||
|
|
||||||
next_pos: StreamToken
|
next_pos: StreamToken
|
||||||
lists: Dict[str, SlidingWindowList]
|
lists: Dict[str, SlidingWindowList]
|
||||||
rooms: Dict[str, RoomResult]
|
rooms: Dict[str, RoomResult]
|
||||||
extensions: JsonMapping
|
extensions: Extensions
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
"""Make the result appear empty if there are no updates. This is used
|
"""Make the result appear empty if there are no updates. This is used
|
||||||
|
@ -271,5 +300,5 @@ class SlidingSyncResult:
|
||||||
next_pos=next_pos,
|
next_pos=next_pos,
|
||||||
lists={},
|
lists={},
|
||||||
rooms={},
|
rooms={},
|
||||||
extensions={},
|
extensions=SlidingSyncResult.Extensions(),
|
||||||
)
|
)
|
||||||
|
|
|
@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
|
||||||
class RoomSubscription(CommonRoomParameters):
|
class RoomSubscription(CommonRoomParameters):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class Extension(RequestBodyModel):
|
class Extensions(RequestBodyModel):
|
||||||
|
"""The extensions section of the request.
|
||||||
|
|
||||||
|
Extensions MUST have an `enabled` flag which defaults to `false`. If a client
|
||||||
|
sends an unknown extension name, the server MUST ignore it (or else backwards
|
||||||
|
compatibility between clients and servers is broken when a newer client tries to
|
||||||
|
communicate with an older server).
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ToDeviceExtension(RequestBodyModel):
|
||||||
|
"""The to-device extension (MSC3885)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled
|
||||||
|
limit: Maximum number of to-device messages to return
|
||||||
|
since: The `next_batch` from the previous sync response
|
||||||
|
"""
|
||||||
|
|
||||||
enabled: Optional[StrictBool] = False
|
enabled: Optional[StrictBool] = False
|
||||||
lists: Optional[List[StrictStr]] = None
|
limit: StrictInt = 100
|
||||||
rooms: Optional[List[StrictStr]] = None
|
since: Optional[StrictStr] = None
|
||||||
|
|
||||||
|
@validator("since")
|
||||||
|
def since_token_check(
|
||||||
|
cls, value: Optional[StrictStr]
|
||||||
|
) -> Optional[StrictStr]:
|
||||||
|
# `since` comes in as an opaque string token but we know that it's just
|
||||||
|
# an integer representing the position in the device inbox stream. We
|
||||||
|
# want to pre-validate it to make sure it works fine in downstream code.
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
|
||||||
|
try:
|
||||||
|
int(value)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
"'extensions.to_device.since' is invalid (should look like an int)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
to_device: Optional[ToDeviceExtension] = None
|
||||||
|
|
||||||
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
|
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
|
||||||
else:
|
else:
|
||||||
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
|
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
|
||||||
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
|
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
|
||||||
extensions: Optional[Dict[StrictStr, Extension]] = None
|
extensions: Optional[Extensions] = None
|
||||||
|
|
||||||
@validator("lists")
|
@validator("lists")
|
||||||
def lists_length_check(
|
def lists_length_check(
|
||||||
|
|
|
@ -38,7 +38,16 @@ from synapse.api.constants import (
|
||||||
)
|
)
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.handlers.sliding_sync import StateValues
|
from synapse.handlers.sliding_sync import StateValues
|
||||||
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
|
from synapse.rest.client import (
|
||||||
|
devices,
|
||||||
|
knock,
|
||||||
|
login,
|
||||||
|
read_marker,
|
||||||
|
receipts,
|
||||||
|
room,
|
||||||
|
sendtodevice,
|
||||||
|
sync,
|
||||||
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
|
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -47,7 +56,7 @@ from tests import unittest
|
||||||
from tests.federation.transport.test_knocking import (
|
from tests.federation.transport.test_knocking import (
|
||||||
KnockingStrippedStateEventHelperMixin,
|
KnockingStrippedStateEventHelperMixin,
|
||||||
)
|
)
|
||||||
from tests.server import TimedOutException
|
from tests.server import FakeChannel, TimedOutException
|
||||||
from tests.test_utils.event_injection import mark_event_as_partial_state
|
from tests.test_utils.event_injection import mark_event_as_partial_state
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
|
||||||
],
|
],
|
||||||
channel.json_body["lists"]["foo-list"],
|
channel.json_body["lists"]["foo-list"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""Tests for the to-device sliding sync extension"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
sync.register_servlets,
|
||||||
|
sendtodevice.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> JsonDict:
|
||||||
|
config = super().default_config()
|
||||||
|
# Enable sliding sync
|
||||||
|
config["experimental_features"] = {"msc3575_enabled": True}
|
||||||
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.sync_endpoint = (
|
||||||
|
"/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assert_to_device_response(
|
||||||
|
self, channel: FakeChannel, expected_messages: List[JsonDict]
|
||||||
|
) -> str:
|
||||||
|
"""Assert the sliding sync response was successful and has the expected
|
||||||
|
to-device messages.
|
||||||
|
|
||||||
|
Returns the next_batch token from the to-device section.
|
||||||
|
"""
|
||||||
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
extensions = channel.json_body["extensions"]
|
||||||
|
to_device = extensions["to_device"]
|
||||||
|
self.assertIsInstance(to_device["next_batch"], str)
|
||||||
|
self.assertEqual(to_device["events"], expected_messages)
|
||||||
|
|
||||||
|
return to_device["next_batch"]
|
||||||
|
|
||||||
|
def test_no_data(self) -> None:
|
||||||
|
"""Test that enabling to-device extension works, even if there is
|
||||||
|
no-data
|
||||||
|
"""
|
||||||
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
user1_tok = self.login(user1_id, "pass")
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.sync_endpoint,
|
||||||
|
{
|
||||||
|
"lists": {},
|
||||||
|
"extensions": {
|
||||||
|
"to_device": {
|
||||||
|
"enabled": True,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
access_token=user1_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We expect no to-device messages
|
||||||
|
self._assert_to_device_response(channel, [])
|
||||||
|
|
||||||
|
def test_data_initial_sync(self) -> None:
|
||||||
|
"""Test that we get to-device messages when we don't specify a since
|
||||||
|
token"""
|
||||||
|
|
||||||
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
user1_tok = self.login(user1_id, "pass", "d1")
|
||||||
|
user2_id = self.register_user("u2", "pass")
|
||||||
|
user2_tok = self.login(user2_id, "pass", "d2")
|
||||||
|
|
||||||
|
# Send the to-device message
|
||||||
|
test_msg = {"foo": "bar"}
|
||||||
|
chan = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
"/_matrix/client/r0/sendToDevice/m.test/1234",
|
||||||
|
content={"messages": {user1_id: {"d1": test_msg}}},
|
||||||
|
access_token=user2_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.sync_endpoint,
|
||||||
|
{
|
||||||
|
"lists": {},
|
||||||
|
"extensions": {
|
||||||
|
"to_device": {
|
||||||
|
"enabled": True,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
access_token=user1_tok,
|
||||||
|
)
|
||||||
|
self._assert_to_device_response(
|
||||||
|
channel,
|
||||||
|
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_data_incremental_sync(self) -> None:
|
||||||
|
"""Test that we get to-device messages over incremental syncs"""
|
||||||
|
|
||||||
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
user1_tok = self.login(user1_id, "pass", "d1")
|
||||||
|
user2_id = self.register_user("u2", "pass")
|
||||||
|
user2_tok = self.login(user2_id, "pass", "d2")
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.sync_endpoint,
|
||||||
|
{
|
||||||
|
"lists": {},
|
||||||
|
"extensions": {
|
||||||
|
"to_device": {
|
||||||
|
"enabled": True,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
access_token=user1_tok,
|
||||||
|
)
|
||||||
|
# No to-device messages yet.
|
||||||
|
next_batch = self._assert_to_device_response(channel, [])
|
||||||
|
|
||||||
|
test_msg = {"foo": "bar"}
|
||||||
|
chan = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
"/_matrix/client/r0/sendToDevice/m.test/1234",
|
||||||
|
content={"messages": {user1_id: {"d1": test_msg}}},
|
||||||
|
access_token=user2_tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(chan.code, 200, chan.result)
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.sync_endpoint,
|
||||||
|
{
|
||||||
|
"lists": {},
|
||||||
|
"extensions": {
|
||||||
|
"to_device": {
|
||||||
|
"enabled": True,
|
||||||
|
"since": next_batch,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
access_token=user1_tok,
|
||||||
|
)
|
||||||
|
next_batch = self._assert_to_device_response(
|
||||||
|
channel,
|
||||||
|
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
# The next sliding sync request should not include the to-device
|
||||||
|
# message.
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.sync_endpoint,
|
||||||
|
{
|
||||||
|
"lists": {},
|
||||||
|
"extensions": {
|
||||||
|
"to_device": {
|
||||||
|
"enabled": True,
|
||||||
|
"since": next_batch,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
access_token=user1_tok,
|
||||||
|
)
|
||||||
|
self._assert_to_device_response(channel, [])
|
||||||
|
|
||||||
|
# An initial sliding sync request should not include the to-device
|
||||||
|
# message, as it should have been deleted
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
self.sync_endpoint,
|
||||||
|
{
|
||||||
|
"lists": {},
|
||||||
|
"extensions": {
|
||||||
|
"to_device": {
|
||||||
|
"enabled": True,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
access_token=user1_tok,
|
||||||
|
)
|
||||||
|
self._assert_to_device_response(channel, [])
|
||||||
|
|
Loading…
Reference in a new issue