forked from MirrorHub/synapse
Support pagination tokens from /sync and /messages in the relations API. (#11952)
This commit is contained in:
parent
337f38cac3
commit
df36945ff0
5 changed files with 217 additions and 53 deletions
1
changelog.d/11952.bugfix
Normal file
1
changelog.d/11952.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug where pagination tokens from `/sync` and `/messages` could not be provided to the `/relations` API.
|
|
@ -32,14 +32,45 @@ from synapse.storage.relations import (
|
|||
PaginationChunk,
|
||||
RelationPaginationToken,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, RoomStreamToken, StreamToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _parse_token(
|
||||
store: "DataStore", token: Optional[str]
|
||||
) -> Optional[StreamToken]:
|
||||
"""
|
||||
For backwards compatibility support RelationPaginationToken, but new pagination
|
||||
tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
|
||||
"""
|
||||
if not token:
|
||||
return None
|
||||
# Luckily the format for StreamToken and RelationPaginationToken differ enough
|
||||
# that they can easily be separated. An "_" appears in the serialization of
|
||||
# RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
|
||||
# "-" only for separators.
|
||||
if "_" in token:
|
||||
return await StreamToken.from_string(store, token)
|
||||
else:
|
||||
relation_token = RelationPaginationToken.from_string(token)
|
||||
return StreamToken(
|
||||
room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
|
||||
presence_key=0,
|
||||
typing_key=0,
|
||||
receipt_key=0,
|
||||
account_data_key=0,
|
||||
push_rules_key=0,
|
||||
to_device_key=0,
|
||||
device_list_key=0,
|
||||
groups_key=0,
|
||||
)
|
||||
|
||||
|
||||
class RelationPaginationServlet(RestServlet):
|
||||
"""API to paginate relations on an event by topological ordering, optionally
|
||||
filtered by relation type and event type.
|
||||
|
@ -88,13 +119,8 @@ class RelationPaginationServlet(RestServlet):
|
|||
pagination_chunk = PaginationChunk(chunk=[])
|
||||
else:
|
||||
# Return the relations
|
||||
from_token = None
|
||||
if from_token_str:
|
||||
from_token = RelationPaginationToken.from_string(from_token_str)
|
||||
|
||||
to_token = None
|
||||
if to_token_str:
|
||||
to_token = RelationPaginationToken.from_string(to_token_str)
|
||||
from_token = await _parse_token(self.store, from_token_str)
|
||||
to_token = await _parse_token(self.store, to_token_str)
|
||||
|
||||
pagination_chunk = await self.store.get_relations_for_event(
|
||||
event_id=parent_id,
|
||||
|
@ -125,7 +151,7 @@ class RelationPaginationServlet(RestServlet):
|
|||
events, now, bundle_aggregations=aggregations
|
||||
)
|
||||
|
||||
return_value = pagination_chunk.to_dict()
|
||||
return_value = await pagination_chunk.to_dict(self.store)
|
||||
return_value["chunk"] = serialized_events
|
||||
return_value["original_event"] = original_event
|
||||
|
||||
|
@ -216,7 +242,7 @@ class RelationAggregationPaginationServlet(RestServlet):
|
|||
to_token=to_token,
|
||||
)
|
||||
|
||||
return 200, pagination_chunk.to_dict()
|
||||
return 200, await pagination_chunk.to_dict(self.store)
|
||||
|
||||
|
||||
class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||
|
@ -287,13 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
|||
from_token_str = parse_string(request, "from")
|
||||
to_token_str = parse_string(request, "to")
|
||||
|
||||
from_token = None
|
||||
if from_token_str:
|
||||
from_token = RelationPaginationToken.from_string(from_token_str)
|
||||
|
||||
to_token = None
|
||||
if to_token_str:
|
||||
to_token = RelationPaginationToken.from_string(to_token_str)
|
||||
from_token = await _parse_token(self.store, from_token_str)
|
||||
to_token = await _parse_token(self.store, to_token_str)
|
||||
|
||||
result = await self.store.get_relations_for_event(
|
||||
event_id=parent_id,
|
||||
|
@ -313,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
|||
now = self.clock.time_msec()
|
||||
serialized_events = self._event_serializer.serialize_events(events, now)
|
||||
|
||||
return_value = result.to_dict()
|
||||
return_value = await result.to_dict(self.store)
|
||||
return_value["chunk"] = serialized_events
|
||||
|
||||
return 200, return_value
|
||||
|
|
|
@ -39,16 +39,13 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.relations import (
|
||||
AggregationPaginationToken,
|
||||
PaginationChunk,
|
||||
RelationPaginationToken,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
|
||||
from synapse.types import JsonDict, RoomStreamToken, StreamToken
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -98,8 +95,8 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
aggregation_key: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
from_token: Optional[RelationPaginationToken] = None,
|
||||
to_token: Optional[RelationPaginationToken] = None,
|
||||
from_token: Optional[StreamToken] = None,
|
||||
to_token: Optional[StreamToken] = None,
|
||||
) -> PaginationChunk:
|
||||
"""Get a list of relations for an event, ordered by topological ordering.
|
||||
|
||||
|
@ -138,8 +135,10 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
pagination_clause = generate_pagination_where_clause(
|
||||
direction=direction,
|
||||
column_names=("topological_ordering", "stream_ordering"),
|
||||
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
|
||||
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
|
||||
from_token=from_token.room_key.as_historical_tuple()
|
||||
if from_token
|
||||
else None,
|
||||
to_token=to_token.room_key.as_historical_tuple() if to_token else None,
|
||||
engine=self.database_engine,
|
||||
)
|
||||
|
||||
|
@ -177,12 +176,27 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
last_topo_id = row[1]
|
||||
last_stream_id = row[2]
|
||||
|
||||
next_batch = None
|
||||
# If there are more events, generate the next pagination key.
|
||||
next_token = None
|
||||
if len(events) > limit and last_topo_id and last_stream_id:
|
||||
next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
|
||||
next_key = RoomStreamToken(last_topo_id, last_stream_id)
|
||||
if from_token:
|
||||
next_token = from_token.copy_and_replace("room_key", next_key)
|
||||
else:
|
||||
next_token = StreamToken(
|
||||
room_key=next_key,
|
||||
presence_key=0,
|
||||
typing_key=0,
|
||||
receipt_key=0,
|
||||
account_data_key=0,
|
||||
push_rules_key=0,
|
||||
to_device_key=0,
|
||||
device_list_key=0,
|
||||
groups_key=0,
|
||||
)
|
||||
|
||||
return PaginationChunk(
|
||||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
|
||||
chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -676,13 +690,15 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
|
||||
if annotations.chunk:
|
||||
aggregations.annotations = annotations.to_dict()
|
||||
aggregations.annotations = await annotations.to_dict(
|
||||
cast("DataStore", self)
|
||||
)
|
||||
|
||||
references = await self.get_relations_for_event(
|
||||
event_id, room_id, RelationTypes.REFERENCE, direction="f"
|
||||
)
|
||||
if references.chunk:
|
||||
aggregations.references = references.to_dict()
|
||||
aggregations.references = await references.to_dict(cast("DataStore", self))
|
||||
|
||||
# If this event is the start of a thread, include a summary of the replies.
|
||||
if self._msc3440_enabled:
|
||||
|
|
|
@ -13,13 +13,16 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -39,14 +42,14 @@ class PaginationChunk:
|
|||
next_batch: Optional[Any] = None
|
||||
prev_batch: Optional[Any] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
|
||||
d = {"chunk": self.chunk}
|
||||
|
||||
if self.next_batch:
|
||||
d["next_batch"] = self.next_batch.to_string()
|
||||
d["next_batch"] = await self.next_batch.to_string(store)
|
||||
|
||||
if self.prev_batch:
|
||||
d["prev_batch"] = self.prev_batch.to_string()
|
||||
d["prev_batch"] = await self.prev_batch.to_string(store)
|
||||
|
||||
return d
|
||||
|
||||
|
@ -75,7 +78,7 @@ class RelationPaginationToken:
|
|||
except ValueError:
|
||||
raise SynapseError(400, "Invalid relation pagination token")
|
||||
|
||||
def to_string(self) -> str:
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
return "%d-%d" % (self.topological, self.stream)
|
||||
|
||||
def as_tuple(self) -> Tuple[Any, ...]:
|
||||
|
@ -105,7 +108,7 @@ class AggregationPaginationToken:
|
|||
except ValueError:
|
||||
raise SynapseError(400, "Invalid aggregation pagination token")
|
||||
|
||||
def to_string(self) -> str:
|
||||
async def to_string(self, store: "DataStore") -> str:
|
||||
return "%d-%d" % (self.count, self.stream)
|
||||
|
||||
def as_tuple(self) -> Tuple[Any, ...]:
|
||||
|
|
|
@ -21,7 +21,8 @@ from unittest.mock import patch
|
|||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, register, relations, room, sync
|
||||
from synapse.types import JsonDict
|
||||
from synapse.storage.relations import RelationPaginationToken
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
|
@ -200,6 +201,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
channel.json_body.get("next_batch"), str, channel.json_body
|
||||
)
|
||||
|
||||
def _stream_token_to_relation_token(self, token: str) -> str:
|
||||
"""Convert a StreamToken into a legacy token (RelationPaginationToken)."""
|
||||
room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
|
||||
return self.get_success(
|
||||
RelationPaginationToken(
|
||||
topological=room_key.topological, stream=room_key.stream
|
||||
).to_string(self.store)
|
||||
)
|
||||
|
||||
def test_repeated_paginate_relations(self):
|
||||
"""Test that if we paginate using a limit and tokens then we get the
|
||||
expected events.
|
||||
|
@ -213,7 +223,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
expected_event_ids.append(channel.json_body["event_id"])
|
||||
|
||||
prev_token: Optional[str] = None
|
||||
prev_token = ""
|
||||
found_event_ids: List[str] = []
|
||||
for _ in range(20):
|
||||
from_token = ""
|
||||
|
@ -222,8 +232,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
|
||||
% (self.room, self.parent_id, from_token),
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
|
@ -241,6 +250,93 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
found_event_ids.reverse()
|
||||
self.assertEquals(found_event_ids, expected_event_ids)
|
||||
|
||||
# Reset and try again, but convert the tokens to the legacy format.
|
||||
prev_token = ""
|
||||
found_event_ids = []
|
||||
for _ in range(20):
|
||||
from_token = ""
|
||||
if prev_token:
|
||||
from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
|
||||
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
|
||||
next_batch = channel.json_body.get("next_batch")
|
||||
|
||||
self.assertNotEquals(prev_token, next_batch)
|
||||
prev_token = next_batch
|
||||
|
||||
if not prev_token:
|
||||
break
|
||||
|
||||
# We paginated backwards, so reverse
|
||||
found_event_ids.reverse()
|
||||
self.assertEquals(found_event_ids, expected_event_ids)
|
||||
|
||||
def test_pagination_from_sync_and_messages(self):
|
||||
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
annotation_id = channel.json_body["event_id"]
|
||||
# Send an event after the relation events.
|
||||
self.helper.send(self.room, body="Latest event", tok=self.user_token)
|
||||
|
||||
# Request /sync, limiting it such that only the latest event is returned
|
||||
# (and not the relation).
|
||||
filter = urllib.parse.quote_plus(
|
||||
'{"room": {"timeline": {"limit": 1}}}'.encode()
|
||||
)
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?filter={filter}", access_token=self.user_token
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
|
||||
sync_prev_batch = room_timeline["prev_batch"]
|
||||
self.assertIsNotNone(sync_prev_batch)
|
||||
# Ensure the relation event is not in the batch returned from /sync.
|
||||
self.assertNotIn(
|
||||
annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
|
||||
)
|
||||
|
||||
# Request /messages, limiting it such that only the latest event is
|
||||
# returned (and not the relation).
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{self.room}/messages?dir=b&limit=1",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
messages_end = channel.json_body["end"]
|
||||
self.assertIsNotNone(messages_end)
|
||||
# Ensure the relation event is not in the chunk returned from /messages.
|
||||
self.assertNotIn(
|
||||
annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||
)
|
||||
|
||||
# Request /relations with the pagination tokens received from both the
|
||||
# /sync and /messages responses above, in turn.
|
||||
#
|
||||
# This is a tiny bit silly since the client wouldn't know the parent ID
|
||||
# from the requests above; consider the parent ID to be known from a
|
||||
# previous /sync.
|
||||
for from_token in (sync_prev_batch, messages_end):
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
|
||||
# The relation should be in the returned chunk.
|
||||
self.assertIn(
|
||||
annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||
)
|
||||
|
||||
def test_aggregation_pagination_groups(self):
|
||||
"""Test that we can paginate annotation groups correctly."""
|
||||
|
||||
|
@ -337,7 +433,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
|
||||
prev_token: Optional[str] = None
|
||||
prev_token = ""
|
||||
found_event_ids: List[str] = []
|
||||
encoded_key = urllib.parse.quote_plus("👍".encode())
|
||||
for _ in range(20):
|
||||
|
@ -347,15 +443,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s"
|
||||
"/aggregations/%s/%s/m.reaction/%s?limit=1%s"
|
||||
% (
|
||||
self.room,
|
||||
self.parent_id,
|
||||
RelationTypes.ANNOTATION,
|
||||
encoded_key,
|
||||
from_token,
|
||||
),
|
||||
f"/_matrix/client/unstable/rooms/{self.room}"
|
||||
f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
|
||||
f"/m.reaction/{encoded_key}?limit=1{from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
|
||||
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
|
||||
|
||||
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
|
||||
|
||||
next_batch = channel.json_body.get("next_batch")
|
||||
|
||||
self.assertNotEquals(prev_token, next_batch)
|
||||
prev_token = next_batch
|
||||
|
||||
if not prev_token:
|
||||
break
|
||||
|
||||
# We paginated backwards, so reverse
|
||||
found_event_ids.reverse()
|
||||
self.assertEquals(found_event_ids, expected_event_ids)
|
||||
|
||||
# Reset and try again, but convert the tokens to the legacy format.
|
||||
prev_token = ""
|
||||
found_event_ids = []
|
||||
for _ in range(20):
|
||||
from_token = ""
|
||||
if prev_token:
|
||||
from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{self.room}"
|
||||
f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
|
||||
f"/m.reaction/{encoded_key}?limit=1{from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
|
|
Loading…
Reference in a new issue