Remove an unnecessary class from the relations code. ()

The PaginationChunk class attempted to bundle some properties
together, but really just caused callers to jump through hoops and
hid implementation details.
This commit is contained in:
Patrick Cloke 2022-03-31 07:13:49 -04:00 committed by GitHub
parent 15cdcf8f30
commit adbf975623
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 74 deletions
changelog.d
synapse
handlers
storage

1
changelog.d/12338.misc Normal file
View file

@ -0,0 +1 @@
Refactor relations code to remove an unnecessary class.

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
from typing import TYPE_CHECKING, Dict, Iterable, Optional
import attr
from frozendict import frozendict
@ -25,7 +25,6 @@ from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@ -116,7 +115,7 @@ class RelationsHandler:
if event is None:
raise SynapseError(404, "Unknown parent event.")
pagination_chunk = await self._main_store.get_relations_for_event(
related_events, next_token = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
@ -129,9 +128,7 @@ class RelationsHandler:
to_token=to_token,
)
events = await self._main_store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
events = await self._main_store.get_events_as_list(related_events)
events = await filter_events_for_client(
self._storage, user_id, events, is_peeking=(member_event_id is None)
@ -152,9 +149,16 @@ class RelationsHandler:
events, now, bundle_aggregations=aggregations
)
return_value = await pagination_chunk.to_dict(self._main_store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return_value = {
"chunk": serialized_events,
"original_event": original_event,
}
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
if from_token:
return_value["prev_batch"] = await from_token.to_string(self._main_store)
return return_value
@ -196,11 +200,18 @@ class RelationsHandler:
if annotations:
aggregations.annotations = {"chunk": annotations}
references = await self._main_store.get_relations_for_event(
references, next_token = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
if references:
aggregations.references = {
"chunk": [{"event_id": event_id} for event_id in references]
}
if next_token:
aggregations.references["next_batch"] = await next_token.to_string(
self._main_store
)
# Store the bundled aggregations in the event metadata for later use.
return aggregations

View file

@ -37,7 +37,6 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import PaginationChunk
from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
@ -71,7 +70,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> PaginationChunk:
) -> Tuple[List[str], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@ -88,8 +87,10 @@ class RelationsWorkerStore(SQLBaseStore):
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
A tuple of:
A list of related event IDs
The next stream token, if one exists.
"""
# We don't use `event_id`, it's there so that we can cache based on
# it. The `event_id` must match the `event.event_id`.
@ -144,7 +145,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
) -> Tuple[List[str], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@ -154,7 +155,7 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
events.append({"event_id": row[0]})
events.append(row[0])
last_topo_id = row[2]
last_stream_id = row[3]
@ -177,9 +178,7 @@ class RelationsWorkerStore(SQLBaseStore):
groups_key=0,
)
return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
)
return events[:limit], next_token
return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn

View file

@ -1,53 +0,0 @@
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import attr
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@attr.s(slots=True, auto_attribs=True)
class PaginationChunk:
"""Returned by relation pagination APIs.
Attributes:
chunk: The rows returned by pagination
next_batch: Token to fetch next set of results with, if
None then there are no more results.
prev_batch: Token to fetch previous set of results with, if
None then there are no previous results.
"""
chunk: List[JsonDict]
next_batch: Optional[Any] = None
prev_batch: Optional[Any] = None
async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
d = {"chunk": self.chunk}
if self.next_batch:
d["next_batch"] = await self.next_batch.to_string(store)
if self.prev_batch:
d["prev_batch"] = await self.prev_batch.to_string(store)
return d