From 361bdafb877e4303497591a1612aa02f0b00c472 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Nov 2024 14:45:57 +0000 Subject: [PATCH] Add experimental support for MSC4222 (#17888) Basically, if the client sets a special query param on `/sync` v2 instead of responding with `state` at the *start* of the timeline, we instead respond with `state_after` at the *end* of the timeline. We do this by using the `current_state_delta_stream` table, which is actually reliable, rather than messing around with "state at" points on the timeline. c.f. MSC4222 --- changelog.d/17888.feature | 1 + docs/admin_api/experimental_features.md | 1 + synapse/config/experimental.py | 3 + synapse/handlers/sync.py | 128 +++++++++- synapse/rest/admin/experimental_features.py | 3 + synapse/rest/client/sync.py | 45 +++- tests/handlers/test_sync.py | 270 ++++++++++++++++---- 7 files changed, 395 insertions(+), 56 deletions(-) create mode 100644 changelog.d/17888.feature diff --git a/changelog.d/17888.feature b/changelog.d/17888.feature new file mode 100644 index 000000000..3ede8886a --- /dev/null +++ b/changelog.d/17888.feature @@ -0,0 +1 @@ +Add experimental support for [MSC4222](https://github.com/matrix-org/matrix-spec-proposals/pull/4222). diff --git a/docs/admin_api/experimental_features.md b/docs/admin_api/experimental_features.md index ef1b58c9b..e32728e56 100644 --- a/docs/admin_api/experimental_features.md +++ b/docs/admin_api/experimental_features.md @@ -5,6 +5,7 @@ basis. The currently supported features are: - [MSC3881](https://github.com/matrix-org/matrix-spec-proposals/pull/3881): enable remotely toggling push notifications for another client - [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): enable experimental sliding sync support +- [MSC4222](https://github.com/matrix-org/matrix-spec-proposals/pull/4222): adding `state_after` to sync v2 To use it, you will need to authenticate by providing an `access_token` for a server admin: see [Admin API](../usage/administration/admin_api/). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index fd14db0d0..b26ce25d7 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -450,3 +450,6 @@ class ExperimentalConfig(Config): # MSC4210: Remove legacy mentions self.msc4210_enabled: bool = experimental.get("msc4210_enabled", False) + + # MSC4222: Adding `state_after` to sync v2 + self.msc4222_enabled: bool = experimental.get("msc4222_enabled", False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f4ea90fbd..df9a08806 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -143,6 +143,7 @@ class SyncConfig: filter_collection: FilterCollection is_guest: bool device_id: Optional[str] + use_state_after: bool @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1141,6 +1142,7 @@ class SyncHandler: since_token: Optional[StreamToken], end_token: StreamToken, full_state: bool, + joined: bool, ) -> MutableStateMap[EventBase]: """Works out the difference in state between the end of the previous sync and the start of the timeline. @@ -1155,6 +1157,7 @@ class SyncHandler: the point just after their leave event. full_state: Whether to force returning the full state. `lazy_load_members` still applies when `full_state` is `True`. + joined: whether the user is currently joined to the room Returns: The state to return in the sync response for the room. @@ -1230,11 +1233,12 @@ class SyncHandler: if full_state: state_ids = await self._compute_state_delta_for_full_sync( room_id, - sync_config.user, + sync_config, batch, end_token, members_to_fetch, timeline_state, + joined, ) else: # If this is an initial sync then full_state should be set, and @@ -1244,6 +1248,7 @@ class SyncHandler: state_ids = await self._compute_state_delta_for_incremental_sync( room_id, + sync_config, batch, since_token, end_token, @@ -1316,20 +1321,24 @@ class SyncHandler: async def _compute_state_delta_for_full_sync( self, room_id: str, - syncing_user: UserID, + sync_config: SyncConfig, batch: TimelineBatch, end_token: StreamToken, members_to_fetch: Optional[Set[str]], timeline_state: StateMap[str], + joined: bool, ) -> StateMap[str]: """Calculate the state events to be included in a full sync response. As with `_compute_state_delta_for_incremental_sync`, the result will include the membership events for the senders of each event in `members_to_fetch`. + Note that whether this returns the state at the start or the end of the + batch depends on `sync_config.use_state_after` (c.f. MSC4222). + Args: room_id: The room we are calculating for. - syncing_user: The user that is calling `/sync`. + sync_confg: The user that is calling `/sync`. batch: The timeline batch for the room that will be sent to the user. end_token: Token of the end of the current batch. Normally this will be the same as the global "now_token", but if the user has left the room, @@ -1338,10 +1347,11 @@ class SyncHandler: events in the timeline. timeline_state: The contribution to the room state from state events in `batch`. Only contains the last event for any given state key. + joined: whether the user is currently joined to the room Returns: A map from (type, state_key) to event_id, for each event that we believe - should be included in the `state` part of the sync response. + should be included in the `state` or `state_after` part of the sync response. """ if members_to_fetch is not None: # Lazy-loading of membership events is enabled. @@ -1359,7 +1369,7 @@ class SyncHandler: # is no guarantee that our membership will be in the auth events of # timeline events when the room is partial stated. state_filter = StateFilter.from_lazy_load_member_list( - members_to_fetch.union((syncing_user.to_string(),)) + members_to_fetch.union((sync_config.user.to_string(),)) ) # We are happy to use partial state to compute the `/sync` response. @@ -1373,6 +1383,61 @@ class SyncHandler: await_full_state = True lazy_load_members = False + # Check if we are wanting to return the state at the start or end of the + # timeline. If at the end we can just use the current state. + if sync_config.use_state_after: + # If we're getting the state at the end of the timeline, we can just + # use the current state of the room (and roll back any changes + # between when we fetched the current state and `end_token`). + # + # For rooms we're not joined to, there might be a very large number + # of deltas between `end_token` and "now", and so instead we fetch + # the state at the end of the timeline. + if joined: + state_ids = await self._state_storage_controller.get_current_state_ids( + room_id, + state_filter=state_filter, + await_full_state=await_full_state, + ) + + # Now roll back the state by looking at the state deltas between + # end_token and now. + deltas = await self.store.get_current_state_deltas_for_room( + room_id, + from_token=end_token.room_key, + to_token=self.store.get_room_max_token(), + ) + if deltas: + mutable_state_ids = dict(state_ids) + + # We iterate over the deltas backwards so that if there are + # multiple changes of the same type/state_key we'll + # correctly pick the earliest delta. + for delta in reversed(deltas): + if delta.prev_event_id: + mutable_state_ids[(delta.event_type, delta.state_key)] = ( + delta.prev_event_id + ) + elif (delta.event_type, delta.state_key) in mutable_state_ids: + mutable_state_ids.pop((delta.event_type, delta.state_key)) + + state_ids = mutable_state_ids + + return state_ids + + else: + # Just use state groups to get the state at the end of the + # timeline, i.e. the state at the leave/etc event. + state_at_timeline_end = ( + await self._state_storage_controller.get_state_ids_at( + room_id, + stream_position=end_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + ) + return state_at_timeline_end + state_at_timeline_end = await self._state_storage_controller.get_state_ids_at( room_id, stream_position=end_token, @@ -1405,6 +1470,7 @@ class SyncHandler: async def _compute_state_delta_for_incremental_sync( self, room_id: str, + sync_config: SyncConfig, batch: TimelineBatch, since_token: StreamToken, end_token: StreamToken, @@ -1419,8 +1485,12 @@ class SyncHandler: (`compute_state_delta`) is responsible for keeping track of which membership events we have already sent to the client, and hence ripping them out. + Note that whether this returns the state at the start or the end of the + batch depends on `sync_config.use_state_after` (c.f. MSC4222). + Args: room_id: The room we are calculating for. + sync_config batch: The timeline batch for the room that will be sent to the user. since_token: Token of the end of the previous batch. end_token: Token of the end of the current batch. Normally this will be @@ -1433,7 +1503,7 @@ class SyncHandler: Returns: A map from (type, state_key) to event_id, for each event that we believe - should be included in the `state` part of the sync response. + should be included in the `state` or `state_after` part of the sync response. """ if members_to_fetch is not None: # Lazy-loading is enabled. Only return the state that is needed. @@ -1445,6 +1515,51 @@ class SyncHandler: await_full_state = True lazy_load_members = False + # Check if we are wanting to return the state at the start or end of the + # timeline. If at the end we can just use the current state delta stream. + if sync_config.use_state_after: + delta_state_ids: MutableStateMap[str] = {} + + if members_to_fetch is not None: + # We're lazy-loading, so the client might need some more member + # events to understand the events in this timeline. So we always + # fish out all the member events corresponding to the timeline + # here. The caller will then dedupe any redundant ones. + member_ids = await self._state_storage_controller.get_current_state_ids( + room_id=room_id, + state_filter=StateFilter.from_types( + (EventTypes.Member, member) for member in members_to_fetch + ), + await_full_state=await_full_state, + ) + delta_state_ids.update(member_ids) + + # We don't do LL filtering for incremental syncs - see + # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 + # N.B. this slows down incr syncs as we are now processing way more + # state in the server than if we were LLing. + # + # i.e. we return all state deltas, including membership changes that + # we'd normally exclude due to LL. + deltas = await self.store.get_current_state_deltas_for_room( + room_id=room_id, + from_token=since_token.room_key, + to_token=end_token.room_key, + ) + for delta in deltas: + if delta.event_id is None: + # There was a state reset and this state entry is no longer + # present, but we have no way of informing the client about + # this, so we just skip it for now. + continue + + # Note that deltas are in stream ordering, so if there are + # multiple deltas for a given type/state_key we'll always pick + # the latest one. + delta_state_ids[(delta.event_type, delta.state_key)] = delta.event_id + + return delta_state_ids + # For a non-gappy sync if the events in the timeline are simply a linear # chain (i.e. no merging/branching of the graph), then we know the state # delta between the end of the previous sync and start of the new one is @@ -2867,6 +2982,7 @@ class SyncHandler: since_token, room_builder.end_token, full_state=full_state, + joined=room_builder.rtype == "joined", ) else: # An out of band room won't have any state changes. diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py index d7913896d..afb71f4a0 100644 --- a/synapse/rest/admin/experimental_features.py +++ b/synapse/rest/admin/experimental_features.py @@ -43,12 +43,15 @@ class ExperimentalFeature(str, Enum): MSC3881 = "msc3881" MSC3575 = "msc3575" + MSC4222 = "msc4222" def is_globally_enabled(self, config: "HomeServerConfig") -> bool: if self is ExperimentalFeature.MSC3881: return config.experimental.msc3881_enabled if self is ExperimentalFeature.MSC3575: return config.experimental.msc3575_enabled + if self is ExperimentalFeature.MSC4222: + return config.experimental.msc4222_enabled assert_never(self) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 122708e93..5c62a74f4 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -152,6 +152,14 @@ class SyncRestServlet(RestServlet): filter_id = parse_string(request, "filter") full_state = parse_boolean(request, "full_state", default=False) + use_state_after = False + if await self.store.is_feature_enabled( + user.to_string(), ExperimentalFeature.MSC4222 + ): + use_state_after = parse_boolean( + request, "org.matrix.msc4222.use_state_after", default=False + ) + logger.debug( "/sync: user=%r, timeout=%r, since=%r, " "set_presence=%r, filter_id=%r, device_id=%r", @@ -184,6 +192,7 @@ class SyncRestServlet(RestServlet): full_state, device_id, last_ignore_accdata_streampos, + use_state_after, ) if filter_id is None: @@ -220,6 +229,7 @@ class SyncRestServlet(RestServlet): filter_collection=filter_collection, is_guest=requester.is_guest, device_id=device_id, + use_state_after=use_state_after, ) since_token = None @@ -258,7 +268,7 @@ class SyncRestServlet(RestServlet): # We know that the the requester has an access token since appservices # cannot use sync. response_content = await self.encode_response( - time_now, sync_result, requester, filter_collection + time_now, sync_config, sync_result, requester, filter_collection ) logger.debug("Event formatting complete") @@ -268,6 +278,7 @@ class SyncRestServlet(RestServlet): async def encode_response( self, time_now: int, + sync_config: SyncConfig, sync_result: SyncResult, requester: Requester, filter: FilterCollection, @@ -292,7 +303,7 @@ class SyncRestServlet(RestServlet): ) joined = await self.encode_joined( - sync_result.joined, time_now, serialize_options + sync_config, sync_result.joined, time_now, serialize_options ) invited = await self.encode_invited( @@ -304,7 +315,7 @@ class SyncRestServlet(RestServlet): ) archived = await self.encode_archived( - sync_result.archived, time_now, serialize_options + sync_config, sync_result.archived, time_now, serialize_options ) logger.debug("building sync response dict") @@ -372,6 +383,7 @@ class SyncRestServlet(RestServlet): @trace_with_opname("sync.encode_joined") async def encode_joined( self, + sync_config: SyncConfig, rooms: List[JoinedSyncResult], time_now: int, serialize_options: SerializeEventConfig, @@ -380,6 +392,7 @@ class SyncRestServlet(RestServlet): Encode the joined rooms in a sync result Args: + sync_config rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations serialize_options: Event serializer options @@ -389,7 +402,11 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, time_now, joined=True, serialize_options=serialize_options + sync_config, + room, + time_now, + joined=True, + serialize_options=serialize_options, ) return joined @@ -477,6 +494,7 @@ class SyncRestServlet(RestServlet): @trace_with_opname("sync.encode_archived") async def encode_archived( self, + sync_config: SyncConfig, rooms: List[ArchivedSyncResult], time_now: int, serialize_options: SerializeEventConfig, @@ -485,6 +503,7 @@ class SyncRestServlet(RestServlet): Encode the archived rooms in a sync result Args: + sync_config rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations serialize_options: Event serializer options @@ -494,13 +513,18 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, time_now, joined=False, serialize_options=serialize_options + sync_config, + room, + time_now, + joined=False, + serialize_options=serialize_options, ) return joined async def encode_room( self, + sync_config: SyncConfig, room: Union[JoinedSyncResult, ArchivedSyncResult], time_now: int, joined: bool, @@ -508,6 +532,7 @@ class SyncRestServlet(RestServlet): ) -> JsonDict: """ Args: + sync_config room: sync result for a single room time_now: current time - used as a baseline for age calculations token_id: ID of the user's auth token - used for namespacing @@ -548,13 +573,20 @@ class SyncRestServlet(RestServlet): account_data = room.account_data + # We either include a `state` or `state_after` field depending on + # whether the client has opted in to the newer `state_after` behavior. + if sync_config.use_state_after: + state_key_name = "org.matrix.msc4222.state_after" + else: + state_key_name = "state" + result: JsonDict = { "timeline": { "events": serialized_timeline, "prev_batch": await room.timeline.prev_batch.to_string(self.store), "limited": room.timeline.limited, }, - "state": {"events": serialized_state}, + state_key_name: {"events": serialized_state}, "account_data": {"events": account_data}, } @@ -688,6 +720,7 @@ class SlidingSyncE2eeRestServlet(RestServlet): filter_collection=self.only_member_events_filter_collection, is_guest=requester.is_guest, device_id=device_id, + use_state_after=False, # We don't return any rooms so this flag is a no-op ) since_token = None diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index d7bbc6803..1960d2f0e 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -20,7 +20,7 @@ from typing import Collection, ContextManager, List, Optional from unittest.mock import AsyncMock, Mock, patch -from parameterized import parameterized +from parameterized import parameterized, parameterized_class from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -32,7 +32,13 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_base import event_from_pdu_json -from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVersion +from synapse.handlers.sync import ( + SyncConfig, + SyncRequestKey, + SyncResult, + SyncVersion, + TimelineBatch, +) from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer @@ -58,9 +64,21 @@ def generate_request_key() -> SyncRequestKey: return ("request_key", _request_key) +@parameterized_class( + ("use_state_after",), + [ + (True,), + (False,), + ], + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'state_after' if params_dict['use_state_after'] else 'state'}", +) class SyncTestCase(tests.unittest.HomeserverTestCase): """Tests Sync Handler.""" + use_state_after: bool + servlets = [ admin.register_servlets, knock.register_servlets, @@ -79,7 +97,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self) -> None: user_id1 = "@user1:test" user_id2 = "@user2:test" - sync_config = generate_sync_config(user_id1) + sync_config = generate_sync_config( + user_id1, use_state_after=self.use_state_after + ) requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time @@ -112,7 +132,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = False - sync_config = generate_sync_config(user_id2) + sync_config = generate_sync_config( + user_id2, use_state_after=self.use_state_after + ) requester = create_requester(user_id2) e = self.get_failure( @@ -141,7 +163,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user, device_id="dev"), + sync_config=generate_sync_config( + user, device_id="dev", use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -175,7 +199,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user), + sync_config=generate_sync_config( + user, use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -188,7 +214,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user, device_id="dev"), + sync_config=generate_sync_config( + user, device_id="dev", use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=initial_result.next_batch, @@ -220,7 +248,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user), + sync_config=generate_sync_config( + user, use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -233,7 +263,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user, device_id="dev"), + sync_config=generate_sync_config( + user, device_id="dev", use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=initial_result.next_batch, @@ -276,7 +308,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): alice_sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(owner), - generate_sync_config(owner), + generate_sync_config(owner, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -296,7 +328,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Eve syncs. eve_requester = create_requester(eve) - eve_sync_config = generate_sync_config(eve) + eve_sync_config = generate_sync_config( + eve, use_state_after=self.use_state_after + ) eve_sync_after_ban: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( eve_requester, @@ -367,7 +401,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -396,6 +430,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 2}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -442,7 +477,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -481,6 +516,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): } }, ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -518,6 +554,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ... and a filter that means we only return 1 event, represented by the dashed horizontal lines: `S2` must be included in the `state` section on the second sync. + + When `use_state_after` is enabled, then we expect to see `s2` in the first sync. """ alice = self.register_user("alice", "password") alice_tok = self.login(alice, "password") @@ -528,7 +566,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -554,6 +592,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 1}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -567,10 +606,18 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e3_event], ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [], - ) + + if self.use_state_after: + # When using `state_after` we get told about s2 immediately + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) + else: + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [], + ) # Now send another event that points to S2, but not E3. with self._patch_get_latest_events([s2_event]): @@ -585,6 +632,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 1}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -598,10 +646,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e4_event], ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) + + if self.use_state_after: + # When using `state_after` we got told about s2 previously, so we + # don't again. + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [], + ) + else: + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) def test_state_includes_changes_on_ungappy_syncs(self) -> None: """Test `state` where the sync is not gappy. @@ -638,6 +695,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): This is the last chance for us to tell the client about S2, so it *must* be included in the response. + + When `use_state_after` is enabled, then we expect to see `s2` in the first sync. """ alice = self.register_user("alice", "password") alice_tok = self.login(alice, "password") @@ -648,7 +707,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -673,6 +732,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 1}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -684,7 +744,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e3_event], ) - self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()]) + if self.use_state_after: + # When using `state_after` we get told about s2 immediately + self.assertIn(s2_event, [e.event_id for e in room_sync.state.values()]) + else: + self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()]) # More events, E4 and E5 with self._patch_get_latest_events([e3_event]): @@ -695,7 +759,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): incremental_sync = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=initial_sync_result.next_batch, @@ -710,10 +774,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e4_event, e5_event], ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) + + if self.use_state_after: + # When using `state_after` we got told about s2 previously, so we + # don't again. + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [], + ) + else: + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) @parameterized.expand( [ @@ -721,7 +794,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): (True, False), (False, True), (True, True), - ] + ], + name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}_{p.args[1]}", ) def test_archived_rooms_do_not_include_state_after_leave( self, initial_sync: bool, empty_timeline: bool @@ -749,7 +823,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( bob_requester, - generate_sync_config(bob), + generate_sync_config(bob, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -780,7 +854,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user( bob_requester, generate_sync_config( - bob, filter_collection=FilterCollection(self.hs, filter_dict) + bob, + filter_collection=FilterCollection(self.hs, filter_dict), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -791,7 +867,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): if empty_timeline: # The timeline should be empty self.assertEqual(sync_room_result.timeline.events, []) + else: + # The last three events in the timeline should be those leading up to the + # leave + self.assertEqual( + [e.event_id for e in sync_room_result.timeline.events[-3:]], + [before_message_event, before_state_event, leave_event], + ) + if empty_timeline or self.use_state_after: # And the state should include the leave event... self.assertEqual( sync_room_result.state[("m.room.member", bob)].event_id, leave_event @@ -801,12 +885,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_room_result.state[("test_state", "")].event_id, before_state_event ) else: - # The last three events in the timeline should be those leading up to the - # leave - self.assertEqual( - [e.event_id for e in sync_room_result.timeline.events[-3:]], - [before_message_event, before_state_event, leave_event], - ) # ... And the state should be empty self.assertEqual(sync_room_result.state, {}) @@ -879,7 +957,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -928,7 +1006,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): private_sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(user2), - generate_sync_config(user2), + generate_sync_config(user2, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -954,7 +1032,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -991,7 +1069,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_d = defer.ensureDeferred( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=since_token, @@ -1046,7 +1124,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_d = defer.ensureDeferred( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=since_token, @@ -1062,6 +1140,7 @@ def generate_sync_config( user_id: str, device_id: Optional[str] = "device_id", filter_collection: Optional[FilterCollection] = None, + use_state_after: bool = False, ) -> SyncConfig: """Generate a sync config (with a unique request key). @@ -1069,7 +1148,8 @@ def generate_sync_config( user_id: user who is syncing. device_id: device that is syncing. Defaults to "device_id". filter_collection: filter to apply. Defaults to the default filter (ie, - return everything, with a default limit) + return everything, with a default limit) + use_state_after: whether the `use_state_after` flag was set. """ if filter_collection is None: filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION @@ -1079,4 +1159,106 @@ def generate_sync_config( filter_collection=filter_collection, is_guest=False, device_id=device_id, + use_state_after=use_state_after, ) + + +class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase): + """Tests Sync Handler state behavior when using `use_state_after.""" + + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_handler = self.hs.get_sync_handler() + self.store = self.hs.get_datastores().main + + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth_blocking() + + def test_initial_sync_multiple_deltas(self) -> None: + """Test that if multiple state deltas have happened during processing of + a full state sync we return the correct state""" + + user = self.register_user("user", "password") + tok = self.login("user", "password") + + # Create a room as the user and set some custom state. + joined_room = self.helper.create_room_as(user, tok=tok) + + first_state = self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 1}, tok=tok + ) + + # Take a snapshot of the stream token, to simulate doing an initial sync + # at this point. + end_stream_token = self.hs.get_event_sources().get_current_token() + + # Send some state *after* the stream token + self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 2}, tok=tok + ) + + # Calculating the full state will return the first state, and not the + # second. + state = self.get_success( + self.sync_handler._compute_state_delta_for_full_sync( + room_id=joined_room, + sync_config=generate_sync_config(user, use_state_after=True), + batch=TimelineBatch( + prev_batch=end_stream_token, events=[], limited=True + ), + end_token=end_stream_token, + members_to_fetch=None, + timeline_state={}, + joined=True, + ) + ) + self.assertEqual(state[("m.test_event", "")], first_state["event_id"]) + + def test_incremental_sync_multiple_deltas(self) -> None: + """Test that if multiple state deltas have happened since an incremental + state sync we return the correct state""" + + user = self.register_user("user", "password") + tok = self.login("user", "password") + + # Create a room as the user and set some custom state. + joined_room = self.helper.create_room_as(user, tok=tok) + + # Take a snapshot of the stream token, to simulate doing an incremental sync + # from this point. + since_token = self.hs.get_event_sources().get_current_token() + + self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 1}, tok=tok + ) + + # Send some state *after* the stream token + second_state = self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 2}, tok=tok + ) + + end_stream_token = self.hs.get_event_sources().get_current_token() + + # Calculating the incrementals state will return the second state, and not the + # first. + state = self.get_success( + self.sync_handler._compute_state_delta_for_incremental_sync( + room_id=joined_room, + sync_config=generate_sync_config(user, use_state_after=True), + batch=TimelineBatch( + prev_batch=end_stream_token, events=[], limited=True + ), + since_token=since_token, + end_token=end_stream_token, + members_to_fetch=None, + timeline_state={}, + ) + ) + self.assertEqual(state[("m.test_event", "")], second_state["event_id"])