Add a relations handler to avoid duplication. (#12227)

Adds a handler layer between the REST and datastore layers for relations.
This commit is contained in:
Patrick Cloke 2022-03-16 10:39:15 -04:00 committed by GitHub
parent c486fa5fd9
commit fc9bd620ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 134 additions and 69 deletions

1
changelog.d/12227.misc Normal file
View file

@ -0,0 +1 @@
Refactor the relations endpoints to add a `RelationsHandler`.

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr import attr
@ -422,7 +422,7 @@ class PaginationHandler:
pagin_config: PaginationConfig, pagin_config: PaginationConfig,
as_client_event: bool = True, as_client_event: bool = True,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> Dict[str, Any]: ) -> JsonDict:
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -431,6 +431,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any. pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format. as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None event_filter: Filter to apply to results or None
Returns: Returns:
Pagination API results Pagination API results
""" """

View file

@ -0,0 +1,117 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import SynapseError
from synapse.types import JsonDict, Requester, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
TODO Accept a PaginationConfig instead of individual pagination parameters.
Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
"""
user_id = requester.user.to_string()
await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self._event_handler.get_event(requester.user, room_id, event_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
pagination_chunk = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=aggregation_key,
limit=limit,
direction=direction,
from_token=from_token,
to_token=to_token,
)
events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
now = self._clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self._main_store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return return_value

View file

@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self._relations_handler = hs.get_relations_handler()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, self,
@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id, requester.user.to_string(), allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
direction = parse_string( direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"] request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str: if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str) to_token = await StreamToken.from_string(self.store, to_token_str)
pagination_chunk = await self.store.get_relations_for_event( result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id, event_id=parent_id,
event=event,
room_id=room_id, room_id=room_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = await self.store.get_events_as_list( return 200, result
[c["event_id"] for c in pagination_chunk.chunk]
)
now = self.clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return 200, return_value
class RelationAggregationPaginationServlet(RestServlet): class RelationAggregationPaginationServlet(RestServlet):
@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self._relations_handler = hs.get_relations_handler()
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, self,
@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
room_id,
requester.user.to_string(),
allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION: if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str: if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str) to_token = await StreamToken.from_string(self.store, to_token_str)
result = await self.store.get_relations_for_event( result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id, event_id=parent_id,
event=event,
room_id=room_id, room_id=room_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = await self.store.get_events_as_list( return 200, result
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)
return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events
return 200, return_value
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

View file

@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.handlers.relations import RelationsHandler
from synapse.handlers.room import ( from synapse.handlers.room import (
RoomContextHandler, RoomContextHandler,
RoomCreationHandler, RoomCreationHandler,
@ -719,6 +720,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_pagination_handler(self) -> PaginationHandler: def get_pagination_handler(self) -> PaginationHandler:
return PaginationHandler(self) return PaginationHandler(self)
@cache_in_self
def get_relations_handler(self) -> RelationsHandler:
return RelationsHandler(self)
@cache_in_self @cache_in_self
def get_room_context_handler(self) -> RoomContextHandler: def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self) return RoomContextHandler(self)