0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 04:52:26 +01:00

Sliding sync: Correctly track which read receipts we have or have not sent down. (#17575)

Add connection tracking to the receipts extension.

Based on #17574

---------

Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
This commit is contained in:
Erik Johnston 2024-08-19 21:16:07 +01:00 committed by GitHub
parent 261e746281
commit 8b8d74d12f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 378 additions and 80 deletions

1
changelog.d/17575.misc Normal file
View file

@ -0,0 +1 @@
Correctly track read receipts that should be sent down in experimental sliding sync.

View file

@ -29,6 +29,7 @@ from typing import (
Callable, Callable,
Dict, Dict,
Final, Final,
Generic,
List, List,
Literal, Literal,
Mapping, Mapping,
@ -37,6 +38,7 @@ from typing import (
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
TypeVar,
Union, Union,
cast, cast,
) )
@ -55,6 +57,7 @@ from synapse.api.constants import (
from synapse.api.errors import SlidingSyncUnknownPosition from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.events import EventBase, StrippedStateEvent from synapse.events import EventBase, StrippedStateEvent
from synapse.events.utils import parse_stripped_state_event, strip_event from synapse.events.utils import parse_stripped_state_event, strip_event
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.relations import BundledAggregations from synapse.handlers.relations import BundledAggregations
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
SynapseTags, SynapseTags,
@ -821,7 +824,7 @@ class SlidingSyncHandler:
async def handle_room(room_id: str) -> None: async def handle_room(room_id: str) -> None:
room_sync_result = await self.get_room_sync_data( room_sync_result = await self.get_room_sync_data(
sync_config=sync_config, sync_config=sync_config,
per_connection_state=previous_connection_state, previous_connection_state=previous_connection_state,
room_id=room_id, room_id=room_id,
room_sync_config=relevant_rooms_to_send_map[room_id], room_sync_config=relevant_rooms_to_send_map[room_id],
room_membership_for_user_at_to_token=room_membership_for_user_map[ room_membership_for_user_at_to_token=room_membership_for_user_map[
@ -839,9 +842,13 @@ class SlidingSyncHandler:
with start_active_span("sliding_sync.generate_room_entries"): with start_active_span("sliding_sync.generate_room_entries"):
await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10) await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10)
new_connection_state = previous_connection_state.get_mutable()
extensions = await self.get_extensions_response( extensions = await self.get_extensions_response(
sync_config=sync_config, sync_config=sync_config,
actual_lists=lists, actual_lists=lists,
previous_connection_state=previous_connection_state,
new_connection_state=new_connection_state,
# We're purposely using `relevant_room_map` instead of # We're purposely using `relevant_room_map` instead of
# `relevant_rooms_to_send_map` here. This needs to be all room_ids we could # `relevant_rooms_to_send_map` here. This needs to be all room_ids we could
# send regardless of whether they have an event update or not. The # send regardless of whether they have an event update or not. The
@ -854,8 +861,6 @@ class SlidingSyncHandler:
) )
if has_lists or has_room_subscriptions: if has_lists or has_room_subscriptions:
new_connection_state = previous_connection_state.get_mutable()
# We now calculate if any rooms outside the range have had updates, # We now calculate if any rooms outside the range have had updates,
# which we are not sending down. # which we are not sending down.
# #
@ -886,7 +891,7 @@ class SlidingSyncHandler:
unsent_room_ids = list(missing_event_map_by_room) unsent_room_ids = list(missing_event_map_by_room)
new_connection_state.rooms.record_unsent_rooms( new_connection_state.rooms.record_unsent_rooms(
unsent_room_ids, from_token.stream_token unsent_room_ids, from_token.stream_token.room_key
) )
new_connection_state.rooms.record_sent_rooms( new_connection_state.rooms.record_sent_rooms(
@ -896,7 +901,7 @@ class SlidingSyncHandler:
connection_position = await self.connection_store.record_new_state( connection_position = await self.connection_store.record_new_state(
sync_config=sync_config, sync_config=sync_config,
from_token=from_token, from_token=from_token,
per_connection_state=new_connection_state, new_connection_state=new_connection_state,
) )
elif from_token: elif from_token:
connection_position = from_token.connection_position connection_position = from_token.connection_position
@ -1949,7 +1954,7 @@ class SlidingSyncHandler:
async def get_room_sync_data( async def get_room_sync_data(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
per_connection_state: "PerConnectionState", previous_connection_state: "PerConnectionState",
room_id: str, room_id: str,
room_sync_config: RoomSyncConfig, room_sync_config: RoomSyncConfig,
room_membership_for_user_at_to_token: _RoomMembershipForUser, room_membership_for_user_at_to_token: _RoomMembershipForUser,
@ -1997,7 +2002,7 @@ class SlidingSyncHandler:
from_bound = None from_bound = None
initial = True initial = True
if from_token and not room_membership_for_user_at_to_token.newly_joined: if from_token and not room_membership_for_user_at_to_token.newly_joined:
room_status = per_connection_state.rooms.have_sent_room(room_id) room_status = previous_connection_state.rooms.have_sent_room(room_id)
if room_status.status == HaveSentRoomFlag.LIVE: if room_status.status == HaveSentRoomFlag.LIVE:
from_bound = from_token.stream_token.room_key from_bound = from_token.stream_token.room_key
initial = False initial = False
@ -2476,6 +2481,8 @@ class SlidingSyncHandler:
async def get_extensions_response( async def get_extensions_response(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
previous_connection_state: "PerConnectionState",
new_connection_state: "MutablePerConnectionState",
actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
actual_room_ids: Set[str], actual_room_ids: Set[str],
actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult], actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
@ -2486,6 +2493,9 @@ class SlidingSyncHandler:
Args: Args:
sync_config: Sync configuration sync_config: Sync configuration
new_connection_state: Snapshot of the current per-connection state
new_per_connection_state: A mutable copy of the per-connection
state, used to record updates to the state during this request.
actual_lists: Sliding window API. A map of list key to list results in the actual_lists: Sliding window API. A map of list key to list results in the
Sliding Sync response. Sliding Sync response.
actual_room_ids: The actual room IDs in the the Sliding Sync response. actual_room_ids: The actual room IDs in the the Sliding Sync response.
@ -2530,6 +2540,8 @@ class SlidingSyncHandler:
if sync_config.extensions.receipts is not None: if sync_config.extensions.receipts is not None:
receipts_response = await self.get_receipts_extension_response( receipts_response = await self.get_receipts_extension_response(
sync_config=sync_config, sync_config=sync_config,
previous_connection_state=previous_connection_state,
new_connection_state=new_connection_state,
actual_lists=actual_lists, actual_lists=actual_lists,
actual_room_ids=actual_room_ids, actual_room_ids=actual_room_ids,
actual_room_response_map=actual_room_response_map, actual_room_response_map=actual_room_response_map,
@ -2849,6 +2861,8 @@ class SlidingSyncHandler:
async def get_receipts_extension_response( async def get_receipts_extension_response(
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
previous_connection_state: "PerConnectionState",
new_connection_state: "MutablePerConnectionState",
actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
actual_room_ids: Set[str], actual_room_ids: Set[str],
actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult], actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
@ -2860,6 +2874,9 @@ class SlidingSyncHandler:
Args: Args:
sync_config: Sync configuration sync_config: Sync configuration
previous_connection_state: The current per-connection state
new_connection_state: A mutable copy of the per-connection
state, used to record updates to the state.
actual_lists: Sliding window API. A map of list key to list results in the actual_lists: Sliding window API. A map of list key to list results in the
Sliding Sync response. Sliding Sync response.
actual_room_ids: The actual room IDs in the the Sliding Sync response. actual_room_ids: The actual room IDs in the the Sliding Sync response.
@ -2882,50 +2899,145 @@ class SlidingSyncHandler:
room_id_to_receipt_map: Dict[str, JsonMapping] = {} room_id_to_receipt_map: Dict[str, JsonMapping] = {}
if len(relevant_room_ids) > 0: if len(relevant_room_ids) > 0:
# TODO: Take connection tracking into account so that when a room comes back # We need to handle the different cases depending on if we have sent
# into range we can send the receipts that were missed. # down receipts previously or not, so we split the relevant rooms
receipt_source = self.event_sources.sources.receipt # up into different collections based on status.
receipts, _ = await receipt_source.get_new_events( live_rooms = set()
user=sync_config.user, previously_rooms: Dict[str, MultiWriterStreamToken] = {}
from_key=( initial_rooms = set()
from_token.stream_token.receipt_key
if from_token for room_id in relevant_room_ids:
else MultiWriterStreamToken(stream=0) if not from_token:
), initial_rooms.add(room_id)
continue
# If we're sending down the room from scratch again for some reason, we
# should always resend the receipts as well (regardless of if
# we've sent them down before). This is to mimic the behaviour
# of what happens on initial sync, where you get a chunk of
# timeline with all of the corresponding receipts for the events in the timeline.
room_result = actual_room_response_map.get(room_id)
if room_result is not None and room_result.initial:
initial_rooms.add(room_id)
continue
room_status = previous_connection_state.receipts.have_sent_room(room_id)
if room_status.status == HaveSentRoomFlag.LIVE:
live_rooms.add(room_id)
elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
assert room_status.last_token is not None
previously_rooms[room_id] = room_status.last_token
elif room_status.status == HaveSentRoomFlag.NEVER:
initial_rooms.add(room_id)
else:
assert_never(room_status.status)
# The set of receipts that we fetched. Private receipts need to be
# filtered out before returning.
fetched_receipts = []
# For live rooms we just fetch all receipts in those rooms since the
# `since` token.
if live_rooms:
assert from_token is not None
receipts = await self.store.get_linearized_receipts_for_rooms(
room_ids=live_rooms,
from_key=from_token.stream_token.receipt_key,
to_key=to_token.receipt_key,
)
fetched_receipts.extend(receipts)
# For rooms we've previously sent down, but aren't up to date, we
# need to use the from token from the room status.
if previously_rooms:
for room_id, receipt_token in previously_rooms.items():
# TODO: Limit the number of receipts we're about to send down
# for the room, if its too many we should TODO
previously_receipts = (
await self.store.get_linearized_receipts_for_room(
room_id=room_id,
from_key=receipt_token,
to_key=to_token.receipt_key,
)
)
fetched_receipts.extend(previously_receipts)
# For rooms we haven't previously sent down, we could send all receipts
# from that room but we only want to include receipts for events
# in the timeline to avoid bloating and blowing up the sync response
# as the number of users in the room increases. (this behavior is part of the spec)
for room_id in initial_rooms:
room_result = actual_room_response_map.get(room_id)
if room_result is None:
continue
relevant_event_ids = [
event.event_id for event in room_result.timeline_events
]
# TODO: In the future, it would be good to fetch less receipts
# out of the database in the first place but we would need to
# add a new `event_id` index to `receipts_linearized`.
initial_receipts = await self.store.get_linearized_receipts_for_room(
room_id=room_id,
to_key=to_token.receipt_key, to_key=to_token.receipt_key,
# This is a dummy value and isn't used in the function
limit=0,
room_ids=relevant_room_ids,
is_guest=False,
) )
for receipt in receipts: for receipt in initial_receipts:
content = {
event_id: content_value
for event_id, content_value in receipt["content"].items()
if event_id in relevant_event_ids
}
if content:
fetched_receipts.append(
{
"type": receipt["type"],
"room_id": receipt["room_id"],
"content": content,
}
)
fetched_receipts = ReceiptEventSource.filter_out_private_receipts(
fetched_receipts, sync_config.user.to_string()
)
for receipt in fetched_receipts:
# These fields should exist for every receipt # These fields should exist for every receipt
room_id = receipt["room_id"] room_id = receipt["room_id"]
type = receipt["type"] type = receipt["type"]
content = receipt["content"] content = receipt["content"]
# For `inital: True` rooms, we only want to include receipts for events
# in the timeline.
room_result = actual_room_response_map.get(room_id)
if room_result is not None:
if room_result.initial:
# TODO: In the future, it would be good to fetch less receipts
# out of the database in the first place but we would need to
# add a new `event_id` index to `receipts_linearized`.
relevant_event_ids = [
event.event_id for event in room_result.timeline_events
]
assert isinstance(content, dict)
content = {
event_id: content_value
for event_id, content_value in content.items()
if event_id in relevant_event_ids
}
room_id_to_receipt_map[room_id] = {"type": type, "content": content} room_id_to_receipt_map[room_id] = {"type": type, "content": content}
# Now we update the per-connection state to track which receipts we have
# and haven't sent down.
new_connection_state.receipts.record_sent_rooms(relevant_room_ids)
if from_token:
# Now find the set of rooms that may have receipts that we're not sending
# down. We only need to check rooms that we have previously returned
# receipts for (in `previous_connection_state`) because we only care about
# updating `LIVE` rooms to `PREVIOUSLY`. The `PREVIOUSLY` rooms will just
# stay pointing at their previous position so we don't need to waste time
# checking those and since we default to `NEVER`, rooms that were `NEVER`
# sent before don't need to be recorded as we'll handle them correctly when
# they come into range for the first time.
rooms_no_receipts = [
room_id
for room_id, room_status in previous_connection_state.receipts._statuses.items()
if room_status.status == HaveSentRoomFlag.LIVE
and room_id not in relevant_room_ids
]
changed_rooms = await self.store.get_rooms_with_receipts_between(
rooms_no_receipts,
from_key=from_token.stream_token.receipt_key,
to_key=to_token.receipt_key,
)
new_connection_state.receipts.record_unsent_rooms(
changed_rooms, from_token.stream_token.receipt_key
)
return SlidingSyncResult.Extensions.ReceiptsExtension( return SlidingSyncResult.Extensions.ReceiptsExtension(
room_id_to_receipt_map=room_id_to_receipt_map, room_id_to_receipt_map=room_id_to_receipt_map,
) )
@ -3016,9 +3128,15 @@ class HaveSentRoomFlag(Enum):
LIVE = 3 LIVE = 3
T = TypeVar("T")
@attr.s(auto_attribs=True, slots=True, frozen=True) @attr.s(auto_attribs=True, slots=True, frozen=True)
class HaveSentRoom: class HaveSentRoom(Generic[T]):
"""Whether we have sent the room down a sliding sync connection. """Whether we have sent the room data down a sliding sync connection.
We are generic over the type of token used, e.g. `RoomStreamToken` or
`MultiWriterStreamToken`.
Attributes: Attributes:
status: Flag of if we have or haven't sent down the room status: Flag of if we have or haven't sent down the room
@ -3029,54 +3147,58 @@ class HaveSentRoom:
""" """
status: HaveSentRoomFlag status: HaveSentRoomFlag
last_token: Optional[RoomStreamToken] last_token: Optional[T]
@staticmethod @staticmethod
def previously(last_token: RoomStreamToken) -> "HaveSentRoom": def live() -> "HaveSentRoom[T]":
return HaveSentRoom(HaveSentRoomFlag.LIVE, None)
@staticmethod
def previously(last_token: T) -> "HaveSentRoom[T]":
"""Constructor for `PREVIOUSLY` flag.""" """Constructor for `PREVIOUSLY` flag."""
return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token) return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token)
@staticmethod
HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None) def never() -> "HaveSentRoom[T]":
HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None) return HaveSentRoom(HaveSentRoomFlag.NEVER, None)
@attr.s(auto_attribs=True, slots=True, frozen=True) @attr.s(auto_attribs=True, slots=True, frozen=True)
class RoomStatusMap: class RoomStatusMap(Generic[T]):
"""For a given stream, e.g. events, records what we have or have not sent """For a given stream, e.g. events, records what we have or have not sent
down for that stream in a given room.""" down for that stream in a given room."""
# `room_id` -> `HaveSentRoom` # `room_id` -> `HaveSentRoom`
_statuses: Mapping[str, HaveSentRoom] = attr.Factory(dict) _statuses: Mapping[str, HaveSentRoom[T]] = attr.Factory(dict)
def have_sent_room(self, room_id: str) -> HaveSentRoom: def have_sent_room(self, room_id: str) -> HaveSentRoom[T]:
"""Return whether we have previously sent the room down""" """Return whether we have previously sent the room down"""
return self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER) return self._statuses.get(room_id, HaveSentRoom.never())
def get_mutable(self) -> "MutableRoomStatusMap": def get_mutable(self) -> "MutableRoomStatusMap[T]":
"""Get a mutable copy of this state.""" """Get a mutable copy of this state."""
return MutableRoomStatusMap( return MutableRoomStatusMap(
statuses=self._statuses, statuses=self._statuses,
) )
def copy(self) -> "RoomStatusMap": def copy(self) -> "RoomStatusMap[T]":
"""Make a copy of the class. Useful for converting from a mutable to """Make a copy of the class. Useful for converting from a mutable to
immutable version.""" immutable version."""
return RoomStatusMap(statuses=dict(self._statuses)) return RoomStatusMap(statuses=dict(self._statuses))
class MutableRoomStatusMap(RoomStatusMap): class MutableRoomStatusMap(RoomStatusMap[T]):
"""A mutable version of `RoomStatusMap`""" """A mutable version of `RoomStatusMap`"""
# We use a ChainMap here so that we can easily track what has been updated # We use a ChainMap here so that we can easily track what has been updated
# and what hasn't. Note that when we persist the per connection state this # and what hasn't. Note that when we persist the per connection state this
# will get flattened to a normal dict (via calling `.copy()`) # will get flattened to a normal dict (via calling `.copy()`)
_statuses: typing.ChainMap[str, HaveSentRoom] _statuses: typing.ChainMap[str, HaveSentRoom[T]]
def __init__( def __init__(
self, self,
statuses: Mapping[str, HaveSentRoom], statuses: Mapping[str, HaveSentRoom[T]],
) -> None: ) -> None:
# ChainMap requires a mutable mapping, but we're not actually going to # ChainMap requires a mutable mapping, but we're not actually going to
# mutate it. # mutate it.
@ -3086,22 +3208,20 @@ class MutableRoomStatusMap(RoomStatusMap):
statuses=ChainMap({}, statuses), statuses=ChainMap({}, statuses),
) )
def get_updates(self) -> Mapping[str, HaveSentRoom]: def get_updates(self) -> Mapping[str, HaveSentRoom[T]]:
"""Return only the changes that were made""" """Return only the changes that were made"""
return self._statuses.maps[0] return self._statuses.maps[0]
def record_sent_rooms(self, room_ids: StrCollection) -> None: def record_sent_rooms(self, room_ids: StrCollection) -> None:
"""Record that we have sent these rooms in the response""" """Record that we have sent these rooms in the response"""
for room_id in room_ids: for room_id in room_ids:
current_status = self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER) current_status = self._statuses.get(room_id, HaveSentRoom.never())
if current_status.status == HaveSentRoomFlag.LIVE: if current_status.status == HaveSentRoomFlag.LIVE:
continue continue
self._statuses[room_id] = HAVE_SENT_ROOM_LIVE self._statuses[room_id] = HaveSentRoom.live()
def record_unsent_rooms( def record_unsent_rooms(self, room_ids: StrCollection, from_token: T) -> None:
self, room_ids: StrCollection, from_token: StreamToken
) -> None:
"""Record that we have not sent these rooms in the response, but there """Record that we have not sent these rooms in the response, but there
have been updates. have been updates.
""" """
@ -3116,33 +3236,42 @@ class MutableRoomStatusMap(RoomStatusMap):
# sent anything down this time either so we leave it as NEVER. # sent anything down this time either so we leave it as NEVER.
for room_id in room_ids: for room_id in room_ids:
current_status = self._statuses.get(room_id, HAVE_SENT_ROOM_NEVER) current_status = self._statuses.get(room_id, HaveSentRoom.never())
if current_status.status != HaveSentRoomFlag.LIVE: if current_status.status != HaveSentRoomFlag.LIVE:
continue continue
self._statuses[room_id] = HaveSentRoom.previously(from_token.room_key) self._statuses[room_id] = HaveSentRoom.previously(from_token)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class PerConnectionState: class PerConnectionState:
"""The per-connection state. A snapshot of what we've sent down the connection before. """The per-connection state. A snapshot of what we've sent down the
connection before.
Currently, we track whether we've sent down various aspects of a given room before. Currently, we track whether we've sent down various aspects of a given room
before.
We use the `rooms` field to store the position in the events stream for each room that we've previously sent to the client before. On the next request that includes the room, we can then send only what's changed since that recorded position. We use the `rooms` field to store the position in the events stream for each
room that we've previously sent to the client before. On the next request
that includes the room, we can then send only what's changed since that
recorded position.
Same goes for the `receipts` field so we only need to send the new receipts since the last time you made a sync request. Same goes for the `receipts` field so we only need to send the new receipts
since the last time you made a sync request.
Attributes: Attributes:
rooms: The status of each room for the events stream. rooms: The status of each room for the events stream.
receipts: The status of each room for the receipts stream.
""" """
rooms: RoomStatusMap = attr.Factory(RoomStatusMap) rooms: RoomStatusMap[RoomStreamToken] = attr.Factory(RoomStatusMap)
receipts: RoomStatusMap[MultiWriterStreamToken] = attr.Factory(RoomStatusMap)
def get_mutable(self) -> "MutablePerConnectionState": def get_mutable(self) -> "MutablePerConnectionState":
"""Get a mutable copy of this state.""" """Get a mutable copy of this state."""
return MutablePerConnectionState( return MutablePerConnectionState(
rooms=self.rooms.get_mutable(), rooms=self.rooms.get_mutable(),
receipts=self.receipts.get_mutable(),
) )
@ -3150,10 +3279,11 @@ class PerConnectionState:
class MutablePerConnectionState(PerConnectionState): class MutablePerConnectionState(PerConnectionState):
"""A mutable version of `PerConnectionState`""" """A mutable version of `PerConnectionState`"""
rooms: MutableRoomStatusMap rooms: MutableRoomStatusMap[RoomStreamToken]
receipts: MutableRoomStatusMap[MultiWriterStreamToken]
def has_updates(self) -> bool: def has_updates(self) -> bool:
return bool(self.rooms.get_updates()) return bool(self.rooms.get_updates()) or bool(self.receipts.get_updates())
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -3233,7 +3363,7 @@ class SlidingSyncConnectionStore:
self, self,
sync_config: SlidingSyncConfig, sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken], from_token: Optional[SlidingSyncStreamToken],
per_connection_state: MutablePerConnectionState, new_connection_state: MutablePerConnectionState,
) -> int: ) -> int:
"""Record updated per-connection state, returning the connection """Record updated per-connection state, returning the connection
position associated with the new state. position associated with the new state.
@ -3245,7 +3375,7 @@ class SlidingSyncConnectionStore:
if from_token is not None: if from_token is not None:
prev_connection_token = from_token.connection_position prev_connection_token = from_token.connection_position
if not per_connection_state.has_updates(): if not new_connection_state.has_updates():
return prev_connection_token return prev_connection_token
conn_key = self._get_connection_key(sync_config) conn_key = self._get_connection_key(sync_config)
@ -3259,7 +3389,8 @@ class SlidingSyncConnectionStore:
# We copy the `MutablePerConnectionState` so that the inner `ChainMap`s # We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
# don't grow forever. # don't grow forever.
sync_statuses[new_store_token] = PerConnectionState( sync_statuses[new_store_token] = PerConnectionState(
rooms=per_connection_state.rooms.copy(), rooms=new_connection_state.rooms.copy(),
receipts=new_connection_state.receipts.copy(),
) )
return new_store_token return new_store_token

View file

@ -51,10 +51,12 @@ from synapse.types import (
JsonMapping, JsonMapping,
MultiWriterStreamToken, MultiWriterStreamToken,
PersistedPosition, PersistedPosition,
StrCollection,
) )
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -550,6 +552,46 @@ class ReceiptsWorkerStore(SQLBaseStore):
return results return results
async def get_rooms_with_receipts_between(
self,
room_ids: StrCollection,
from_key: MultiWriterStreamToken,
to_key: MultiWriterStreamToken,
) -> StrCollection:
"""Given a set of room_ids, find out which ones (may) have receipts
between the two tokens (> `from_token` and <= `to_token`)."""
room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key.stream
)
if not room_ids:
return []
def f(txn: LoggingTransaction, room_ids: StrCollection) -> StrCollection:
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
sql = f"""
SELECT DISTINCT room_id FROM receipts_linearized
WHERE {clause} AND ? < stream_id AND stream_id <= ?
"""
args.append(from_key.stream)
args.append(to_key.get_max_stream_pos())
txn.execute(sql, args)
return [room_id for room_id, in txn]
results: List[str] = []
for batch in batch_iter(room_ids, 1000):
batch_result = await self.db_pool.runInteraction(
"get_rooms_with_receipts_between", f, batch
)
results.extend(batch_result)
return results
async def get_users_sent_receipts_between( async def get_users_sent_receipts_between(
self, last_id: int, current_id: int self, last_id: int, current_id: int
) -> List[str]: ) -> List[str]:

View file

@ -677,3 +677,108 @@ class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
set(), set(),
exact=True, exact=True,
) )
def test_receipts_incremental_sync_out_of_range(self) -> None:
"""Tests that we don't return read receipts for rooms that fall out of
range, but then do send all read receipts once they're back in range.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id1, user1_id, tok=user1_tok)
room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id2, user1_id, tok=user1_tok)
# Send a message and read receipt into room2
event_response = self.helper.send(room_id2, body="new event", tok=user2_tok)
room2_event_id = event_response["event_id"]
self.helper.send_read_receipt(room_id2, room2_event_id, tok=user1_tok)
# Now send a message into room1 so that it is at the top of the list
self.helper.send(room_id1, body="new event", tok=user2_tok)
# Make a SS request for only the top room.
sync_body = {
"lists": {
"main": {
"ranges": [[0, 0]],
"required_state": [],
"timeline_limit": 5,
}
},
"extensions": {
"receipts": {
"enabled": True,
}
},
}
response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
# The receipt is in room2, but only room1 is returned, so we don't
# expect to get the receipt.
self.assertIncludes(
response_body["extensions"]["receipts"].get("rooms").keys(),
set(),
exact=True,
)
# Move room2 into range.
self.helper.send(room_id2, body="new event", tok=user2_tok)
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We expect to see the read receipt of room2, as that has the most
# recent update.
self.assertIncludes(
response_body["extensions"]["receipts"].get("rooms").keys(),
{room_id2},
exact=True,
)
receipt = response_body["extensions"]["receipts"]["rooms"][room_id2]
self.assertIncludes(
receipt["content"][room2_event_id][ReceiptTypes.READ].keys(),
{user1_id},
exact=True,
)
# Send a message into room1 to bump it to the top, but also send a
# receipt in room2
self.helper.send(room_id1, body="new event", tok=user2_tok)
self.helper.send_read_receipt(room_id2, room2_event_id, tok=user2_tok)
# We don't expect to see the new read receipt.
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
self.assertIncludes(
response_body["extensions"]["receipts"].get("rooms").keys(),
set(),
exact=True,
)
# But if we send a new message into room2, we expect to get the missing receipts
self.helper.send(room_id2, body="new event", tok=user2_tok)
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
self.assertIncludes(
response_body["extensions"]["receipts"].get("rooms").keys(),
{room_id2},
exact=True,
)
# We should only see the new receipt
receipt = response_body["extensions"]["receipts"]["rooms"][room_id2]
self.assertIncludes(
receipt["content"][room2_event_id][ReceiptTypes.READ].keys(),
{user2_id},
exact=True,
)

View file

@ -120,19 +120,26 @@ class SlidingSyncExtensionsTestCase(SlidingSyncBase):
"foo-list": { "foo-list": {
"ranges": [[0, 1]], "ranges": [[0, 1]],
"required_state": [], "required_state": [],
"timeline_limit": 0, # We set this to `1` because we're testing `receipts` which
# interact with the `timeline`. With receipts, when a room
# hasn't been sent down the connection before or it appears
# as `initial: true`, we only include receipts for events in
# the timeline to avoid bloating and blowing up the sync
# response as the number of users in the room increases.
# (this behavior is part of the spec)
"timeline_limit": 1,
}, },
# We expect this list range to include room5, room4, room3 # We expect this list range to include room5, room4, room3
"bar-list": { "bar-list": {
"ranges": [[0, 2]], "ranges": [[0, 2]],
"required_state": [], "required_state": [],
"timeline_limit": 0, "timeline_limit": 1,
}, },
}, },
"room_subscriptions": { "room_subscriptions": {
room_id1: { room_id1: {
"required_state": [], "required_state": [],
"timeline_limit": 0, "timeline_limit": 1,
} }
}, },
} }

View file

@ -45,7 +45,7 @@ from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.server import Site from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership, ReceiptTypes
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
@ -944,3 +944,15 @@ class RestHelper:
assert len(p.links) == 1, "not exactly one link in confirmation page" assert len(p.links) == 1, "not exactly one link in confirmation page"
oauth_uri = p.links[0] oauth_uri = p.links[0]
return oauth_uri return oauth_uri
def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None:
"""Send a read receipt into the room at the given event"""
channel = make_request(
self.reactor,
self.site,
method="POST",
path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
content={},
access_token=tok,
)
assert channel.code == HTTPStatus.OK, channel.text_body