diff --git a/changelog.d/15315.feature b/changelog.d/15315.feature new file mode 100644 index 000000000..30b2abdc6 --- /dev/null +++ b/changelog.d/15315.feature @@ -0,0 +1 @@ +Experimental support to recursively provide relations per [MSC3981](https://github.com/matrix-org/matrix-spec-proposals/pull/3981). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 659967973..cab7ccf4b 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -192,5 +192,10 @@ class ExperimentalConfig(Config): # MSC2659: Application service ping endpoint self.msc2659_enabled = experimental.get("msc2659_enabled", False) + # MSC3981: Recurse relations + self.msc3981_recurse_relations = experimental.get( + "msc3981_recurse_relations", False + ) + # MSC3970: Scope transaction IDs to devices self.msc3970_enabled = experimental.get("msc3970_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 1d09fdf13..482463516 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -85,6 +85,7 @@ class RelationsHandler: event_id: str, room_id: str, pagin_config: PaginationConfig, + recurse: bool, include_original_event: bool, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -98,6 +99,7 @@ class RelationsHandler: event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. pagin_config: The pagination config rules to apply, if any. + recurse: Whether to recursively find relations. include_original_event: Whether to include the parent event. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -132,6 +134,7 @@ class RelationsHandler: direction=pagin_config.direction, from_token=pagin_config.from_token, to_token=pagin_config.to_token, + recurse=recurse, ) events = await self._main_store.get_events_as_list( diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b8b296bc0..785dfa08d 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.constants import Direction from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.relations import ThreadsNextBatch @@ -49,6 +49,7 @@ class RelationPaginationServlet(RestServlet): self.auth = hs.get_auth() self._store = hs.get_datastores().main self._relations_handler = hs.get_relations_handler() + self._support_recurse = hs.config.experimental.msc3981_recurse_relations async def on_GET( self, @@ -63,6 +64,12 @@ class RelationPaginationServlet(RestServlet): pagination_config = await PaginationConfig.from_request( self._store, request, default_limit=5, default_dir=Direction.BACKWARDS ) + if self._support_recurse: + recurse = parse_boolean( + request, "org.matrix.msc3981.recurse", default=False + ) + else: + recurse = False # The unstable version of this API returns an extra field for client # compatibility, see https://github.com/matrix-org/synapse/issues/12930. @@ -75,6 +82,7 @@ class RelationPaginationServlet(RestServlet): event_id=parent_id, room_id=room_id, pagin_config=pagination_config, + recurse=recurse, include_original_event=include_original_event, relation_type=relation_type, event_type=event_type, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 3955a8a9a..4a6c6c724 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -172,6 +172,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, + recurse: bool = False, ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. @@ -186,6 +187,7 @@ class RelationsWorkerStore(SQLBaseStore): oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. + recurse: Whether to recursively find relations. Returns: A tuple of: @@ -200,8 +202,8 @@ class RelationsWorkerStore(SQLBaseStore): # Ensure bad limits aren't being passed in. assert limit >= 0 - where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event.event_id, room_id] + where_clause = ["room_id = ?"] + where_args: List[Union[str, int]] = [room_id] is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: @@ -229,23 +231,52 @@ class RelationsWorkerStore(SQLBaseStore): if pagination_clause: where_clause.append(pagination_clause) - sql = """ - SELECT event_id, relation_type, sender, topological_ordering, stream_ordering - FROM event_relations - INNER JOIN events USING (event_id) - WHERE %s - ORDER BY topological_ordering %s, stream_ordering %s - LIMIT ? - """ % ( - " AND ".join(where_clause), - order, - order, - ) + # If a recursive query is requested then the filters are applied after + # recursively following relationships from the requested event to children + # up to 3-relations deep. + # + # If no recursion is needed then the event_relations table is queried + # for direct children of the requested event. + if recurse: + sql = """ + WITH RECURSIVE related_events AS ( + SELECT event_id, relation_type, relates_to_id, 0 AS depth + FROM event_relations + WHERE relates_to_id = ? + UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.event_id = e.relates_to_id + WHERE depth <= 3 + ) + SELECT event_id, relation_type, sender, topological_ordering, stream_ordering + FROM related_events + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ?; + """ % ( + " AND ".join(where_clause), + order, + order, + ) + else: + sql = """ + SELECT event_id, relation_type, sender, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) def _get_recent_references_for_event_txn( txn: LoggingTransaction, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: - txn.execute(sql, where_args + [limit + 1]) + txn.execute(sql, [event.event_id] + where_args + [limit + 1]) events = [] topo_orderings: List[int] = [] @@ -965,7 +996,7 @@ class RelationsWorkerStore(SQLBaseStore): # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type, 0 depth + SELECT event_id, relates_to_id, relation_type, 0 AS depth FROM event_relations WHERE event_id = ? UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 @@ -1025,7 +1056,7 @@ class RelationsWorkerStore(SQLBaseStore): sql = """ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type, 0 depth + SELECT event_id, relates_to_id, relation_type, 0 AS depth FROM event_relations WHERE event_id = ? UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index fbbbcb23f..75439416c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -30,6 +30,7 @@ from tests import unittest from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event +from tests.unittest import override_config class BaseRelationsTestCase(unittest.HomeserverTestCase): @@ -949,6 +950,125 @@ class RelationPaginationTestCase(BaseRelationsTestCase): ) +class RecursiveRelationTestCase(BaseRelationsTestCase): + @override_config({"experimental_features": {"msc3981_recurse_relations": True}}) + def test_recursive_relations(self) -> None: + """Generate a complex, multi-level relationship tree and query it.""" + # Create a thread with a few messages in it. + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + # Add annotations. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2 + ) + annotation_1 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1 + ) + annotation_2 = channel.json_body["event_id"] + + # Add a reference to part of the thread, then edit the reference and annotate it. + channel = self._send_relation( + RelationTypes.REFERENCE, "m.room.test", parent_id=thread_2 + ) + reference_1 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "c", parent_id=reference_1 + ) + annotation_3 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.test", + parent_id=reference_1, + ) + edit = channel.json_body["event_id"] + + # Also more events off the root. + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "d") + annotation_4 = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}" + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The above events should be returned in creation order. + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + event_ids, + [ + thread_1, + thread_2, + annotation_1, + annotation_2, + reference_1, + annotation_3, + edit, + annotation_4, + ], + ) + + @override_config({"experimental_features": {"msc3981_recurse_relations": True}}) + def test_recursive_relations_with_filter(self) -> None: + """The event_type and rel_type still apply.""" + # Create a thread with a few messages in it. + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_1 = channel.json_body["event_id"] + + # Add annotations. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1 + ) + annotation_1 = channel.json_body["event_id"] + + # Add a reference to part of the thread, then edit the reference and annotate it. + channel = self._send_relation( + RelationTypes.REFERENCE, "m.room.test", parent_id=thread_1 + ) + reference_1 = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "org.matrix.reaction", "c", parent_id=reference_1 + ) + annotation_2 = channel.json_body["event_id"] + + # Fetch only annotations, but recursively. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}" + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The above events should be returned in creation order. + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(event_ids, [annotation_1, annotation_2]) + + # Fetch only m.reactions, but recursively. + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction" + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + # The above events should be returned in creation order. + event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(event_ids, [annotation_1]) + + class BundledAggregationsTestCase(BaseRelationsTestCase): """ See RelationsTestCase.test_edit for a similar test for edits.