forked from MirrorHub/synapse
Add an API for listing threads in a room. (#13394)
Implement the /threads endpoint from MSC3856. This is currently unstable and behind an experimental configuration flag. It includes a background update to backfill data, results from the /threads endpoint will be partial until that finishes.
This commit is contained in:
parent
b6baa46db0
commit
3bbe532abb
10 changed files with 522 additions and 6 deletions
1
changelog.d/13394.feature
Normal file
1
changelog.d/13394.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
|
|
@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import (
|
||||||
RegistrationBackgroundUpdateStore,
|
RegistrationBackgroundUpdateStore,
|
||||||
find_max_generated_user_id_localpart,
|
find_max_generated_user_id_localpart,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.databases.main.relations import RelationsWorkerStore
|
||||||
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
|
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
|
||||||
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
|
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
|
||||||
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
|
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
|
||||||
|
@ -206,6 +207,7 @@ class Store(
|
||||||
PusherWorkerStore,
|
PusherWorkerStore,
|
||||||
PresenceBackgroundUpdateStore,
|
PresenceBackgroundUpdateStore,
|
||||||
ReceiptsBackgroundUpdateStore,
|
ReceiptsBackgroundUpdateStore,
|
||||||
|
RelationsWorkerStore,
|
||||||
):
|
):
|
||||||
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
|
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
|
||||||
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
|
||||||
|
|
|
@ -101,6 +101,9 @@ class ExperimentalConfig(Config):
|
||||||
# MSC3848: Introduce errcodes for specific event sending failures
|
# MSC3848: Introduce errcodes for specific event sending failures
|
||||||
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
|
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
|
||||||
|
|
||||||
|
# MSC3856: Threads list API
|
||||||
|
self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
|
||||||
|
|
||||||
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
||||||
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.events import EventBase, relation_from_event
|
from synapse.events import EventBase, relation_from_event
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
from synapse.storage.databases.main.relations import _RelatedEvent
|
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import JsonDict, Requester, StreamToken, UserID
|
from synapse.types import JsonDict, Requester, StreamToken, UserID
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
@ -32,6 +33,13 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadsListInclude(str, enum.Enum):
|
||||||
|
"""Valid values for the 'include' flag of /threads."""
|
||||||
|
|
||||||
|
all = "all"
|
||||||
|
participated = "participated"
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
class _ThreadAggregation:
|
class _ThreadAggregation:
|
||||||
# The latest event in the thread.
|
# The latest event in the thread.
|
||||||
|
@ -482,3 +490,79 @@ class RelationsHandler:
|
||||||
results.setdefault(event_id, BundledAggregations()).replace = edit
|
results.setdefault(event_id, BundledAggregations()).replace = edit
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
async def get_threads(
|
||||||
|
self,
|
||||||
|
requester: Requester,
|
||||||
|
room_id: str,
|
||||||
|
include: ThreadsListInclude,
|
||||||
|
limit: int = 5,
|
||||||
|
from_token: Optional[ThreadsNextBatch] = None,
|
||||||
|
) -> JsonDict:
|
||||||
|
"""Get related events of a event, ordered by topological ordering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester: The user requesting the relations.
|
||||||
|
room_id: The room the event belongs to.
|
||||||
|
include: One of "all" or "participated" to indicate which threads should
|
||||||
|
be returned.
|
||||||
|
limit: Only fetch the most recent `limit` events.
|
||||||
|
from_token: Fetch rows from the given token, or from the start if None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The pagination chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
# TODO Properly handle a user leaving a room.
|
||||||
|
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
|
||||||
|
room_id, requester, allow_departed_users=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note that ignored users are not passed into get_relations_for_event
|
||||||
|
# below. Ignored users are handled in filter_events_for_client (and by
|
||||||
|
# not passing them in here we should get a better cache hit rate).
|
||||||
|
thread_roots, next_batch = await self._main_store.get_threads(
|
||||||
|
room_id=room_id, limit=limit, from_token=from_token
|
||||||
|
)
|
||||||
|
|
||||||
|
events = await self._main_store.get_events_as_list(thread_roots)
|
||||||
|
|
||||||
|
if include == ThreadsListInclude.participated:
|
||||||
|
# Pre-seed thread participation with whether the requester sent the event.
|
||||||
|
participated = {event.event_id: event.sender == user_id for event in events}
|
||||||
|
# For events the requester did not send, check the database for whether
|
||||||
|
# the requester sent a threaded reply.
|
||||||
|
participated.update(
|
||||||
|
await self._main_store.get_threads_participated(
|
||||||
|
[eid for eid, p in participated.items() if not p],
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Limit the returned threads to those the user has participated in.
|
||||||
|
events = [event for event in events if participated[event.event_id]]
|
||||||
|
|
||||||
|
events = await filter_events_for_client(
|
||||||
|
self._storage_controllers,
|
||||||
|
user_id,
|
||||||
|
events,
|
||||||
|
is_peeking=(member_event_id is None),
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregations = await self.get_bundled_aggregations(
|
||||||
|
events, requester.user.to_string()
|
||||||
|
)
|
||||||
|
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
serialized_events = self._event_serializer.serialize_events(
|
||||||
|
events, now, bundle_aggregations=aggregations
|
||||||
|
)
|
||||||
|
|
||||||
|
return_value: JsonDict = {"chunk": serialized_events}
|
||||||
|
|
||||||
|
if next_batch:
|
||||||
|
return_value["next_batch"] = str(next_batch)
|
||||||
|
|
||||||
|
return return_value
|
||||||
|
|
|
@ -13,12 +13,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
|
from synapse.handlers.relations import ThreadsListInclude
|
||||||
from synapse.http.server import HttpServer
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.client._base import client_patterns
|
from synapse.rest.client._base import client_patterns
|
||||||
|
from synapse.storage.databases.main.relations import ThreadsNextBatch
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
@ -78,5 +81,50 @@ class RelationPaginationServlet(RestServlet):
|
||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadsServlet(RestServlet):
|
||||||
|
PATTERNS = (
|
||||||
|
re.compile(
|
||||||
|
"^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self._relations_handler = hs.get_relations_handler()
|
||||||
|
|
||||||
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
limit = parse_integer(request, "limit", default=5)
|
||||||
|
from_token_str = parse_string(request, "from")
|
||||||
|
include = parse_string(
|
||||||
|
request,
|
||||||
|
"include",
|
||||||
|
default=ThreadsListInclude.all.value,
|
||||||
|
allowed_values=[v.value for v in ThreadsListInclude],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the relations
|
||||||
|
from_token = None
|
||||||
|
if from_token_str:
|
||||||
|
from_token = ThreadsNextBatch.from_string(from_token_str)
|
||||||
|
|
||||||
|
result = await self._relations_handler.get_threads(
|
||||||
|
requester=requester,
|
||||||
|
room_id=room_id,
|
||||||
|
include=ThreadsListInclude(include),
|
||||||
|
limit=limit,
|
||||||
|
from_token=from_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 200, result
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
RelationPaginationServlet(hs).register(http_server)
|
RelationPaginationServlet(hs).register(http_server)
|
||||||
|
if hs.config.experimental.msc3856_enabled:
|
||||||
|
ThreadsServlet(hs).register(http_server)
|
||||||
|
|
|
@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
|
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
|
||||||
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
|
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
|
||||||
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
|
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
|
||||||
|
self._attempt_to_invalidate_cache("get_threads", (room_id,))
|
||||||
|
|
||||||
async def invalidate_cache_and_stream(
|
async def invalidate_cache_and_stream(
|
||||||
self, cache_name: str, keys: Tuple[Any, ...]
|
self, cache_name: str, keys: Tuple[Any, ...]
|
||||||
|
|
|
@ -35,7 +35,7 @@ import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
from synapse.api.constants import EventContentFields, EventTypes
|
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.events import EventBase, relation_from_event
|
from synapse.events import EventBase, relation_from_event
|
||||||
|
@ -1616,7 +1616,7 @@ class PersistEventsStore:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove from relations table.
|
# Remove from relations table.
|
||||||
self._handle_redact_relations(txn, event.redacts)
|
self._handle_redact_relations(txn, event.room_id, event.redacts)
|
||||||
|
|
||||||
# Update the event_forward_extremities, event_backward_extremities and
|
# Update the event_forward_extremities, event_backward_extremities and
|
||||||
# event_edges tables.
|
# event_edges tables.
|
||||||
|
@ -1866,6 +1866,34 @@ class PersistEventsStore:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if relation.rel_type == RelationTypes.THREAD:
|
||||||
|
# Upsert into the threads table, but only overwrite the value if the
|
||||||
|
# new event is of a later topological order OR if the topological
|
||||||
|
# ordering is equal, but the stream ordering is later.
|
||||||
|
sql = """
|
||||||
|
INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT (room_id, thread_id)
|
||||||
|
DO UPDATE SET
|
||||||
|
latest_event_id = excluded.latest_event_id,
|
||||||
|
topological_ordering = excluded.topological_ordering,
|
||||||
|
stream_ordering = excluded.stream_ordering
|
||||||
|
WHERE
|
||||||
|
threads.topological_ordering <= excluded.topological_ordering AND
|
||||||
|
threads.stream_ordering < excluded.stream_ordering
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(
|
||||||
|
event.room_id,
|
||||||
|
relation.parent_id,
|
||||||
|
event.event_id,
|
||||||
|
event.depth,
|
||||||
|
event.internal_metadata.stream_ordering,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_insertion_event(
|
def _handle_insertion_event(
|
||||||
self, txn: LoggingTransaction, event: EventBase
|
self, txn: LoggingTransaction, event: EventBase
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -1989,13 +2017,14 @@ class PersistEventsStore:
|
||||||
txn.execute(sql, (batch_id,))
|
txn.execute(sql, (batch_id,))
|
||||||
|
|
||||||
def _handle_redact_relations(
|
def _handle_redact_relations(
|
||||||
self, txn: LoggingTransaction, redacted_event_id: str
|
self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handles receiving a redaction and checking whether the redacted event
|
"""Handles receiving a redaction and checking whether the redacted event
|
||||||
has any relations which must be removed from the database.
|
has any relations which must be removed from the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn
|
txn
|
||||||
|
room_id: The room ID of the event that was redacted.
|
||||||
redacted_event_id: The event that was redacted.
|
redacted_event_id: The event that was redacted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -2024,6 +2053,9 @@ class PersistEventsStore:
|
||||||
self.store._invalidate_cache_and_stream(
|
self.store._invalidate_cache_and_stream(
|
||||||
txn, self.store.get_thread_participated, (redacted_relates_to,)
|
txn, self.store.get_thread_participated, (redacted_relates_to,)
|
||||||
)
|
)
|
||||||
|
self.store._invalidate_cache_and_stream(
|
||||||
|
txn, self.store.get_threads, (room_id,)
|
||||||
|
)
|
||||||
|
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
|
@ -29,17 +30,46 @@ from typing import (
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import MAIN_TIMELINE, RelationTypes
|
from synapse.api.constants import MAIN_TIMELINE, RelationTypes
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
make_in_list_sql_clause,
|
||||||
|
)
|
||||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
|
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class ThreadsNextBatch:
|
||||||
|
topological_ordering: int
|
||||||
|
stream_ordering: int
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.topological_ordering}_{self.stream_ordering}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, string: str) -> "ThreadsNextBatch":
|
||||||
|
"""
|
||||||
|
Creates a ThreadsNextBatch from its textual representation.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
keys = (int(s) for s in string.split("_"))
|
||||||
|
return cls(*keys)
|
||||||
|
except Exception:
|
||||||
|
raise SynapseError(400, "Invalid threads token")
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
class _RelatedEvent:
|
class _RelatedEvent:
|
||||||
"""
|
"""
|
||||||
|
@ -56,6 +86,76 @@ class _RelatedEvent:
|
||||||
|
|
||||||
|
|
||||||
class RelationsWorkerStore(SQLBaseStore):
|
class RelationsWorkerStore(SQLBaseStore):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
self.db_pool.updates.register_background_update_handler(
|
||||||
|
"threads_backfill", self._backfill_threads
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
|
||||||
|
"""Backfill the threads table."""
|
||||||
|
|
||||||
|
def threads_backfill_txn(txn: LoggingTransaction) -> int:
|
||||||
|
last_thread_id = progress.get("last_thread_id", "")
|
||||||
|
|
||||||
|
# Get the latest event in each thread by topo ordering / stream ordering.
|
||||||
|
#
|
||||||
|
# Note that the MAX(event_id) is needed to abide by the rules of group by,
|
||||||
|
# but doesn't actually do anything since there should only be a single event
|
||||||
|
# ID per topo/stream ordering pair.
|
||||||
|
sql = f"""
|
||||||
|
SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
|
||||||
|
FROM event_relations
|
||||||
|
INNER JOIN events USING (event_id)
|
||||||
|
WHERE
|
||||||
|
relates_to_id > ? AND
|
||||||
|
relation_type = '{RelationTypes.THREAD}'
|
||||||
|
GROUP BY room_id, relates_to_id
|
||||||
|
ORDER BY relates_to_id
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (last_thread_id, batch_size))
|
||||||
|
|
||||||
|
# No more rows to process.
|
||||||
|
rows = txn.fetchall()
|
||||||
|
if not rows:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Insert the rows into the threads table. If a matching thread already exists,
|
||||||
|
# assume it is from a newer event.
|
||||||
|
sql = """
|
||||||
|
INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
|
||||||
|
VALUES %s
|
||||||
|
ON CONFLICT (room_id, thread_id)
|
||||||
|
DO NOTHING
|
||||||
|
"""
|
||||||
|
if isinstance(txn.database_engine, PostgresEngine):
|
||||||
|
txn.execute_values(sql % ("?",), rows, fetch=False)
|
||||||
|
else:
|
||||||
|
txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows)
|
||||||
|
|
||||||
|
# Mark the progress.
|
||||||
|
self.db_pool.updates._background_update_progress_txn(
|
||||||
|
txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
|
||||||
|
)
|
||||||
|
|
||||||
|
return txn.rowcount
|
||||||
|
|
||||||
|
result = await self.db_pool.runInteraction(
|
||||||
|
"threads_backfill", threads_backfill_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
await self.db_pool.updates._end_background_update("threads_backfill")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@cached(uncached_args=("event",), tree=True)
|
@cached(uncached_args=("event",), tree=True)
|
||||||
async def get_relations_for_event(
|
async def get_relations_for_event(
|
||||||
self,
|
self,
|
||||||
|
@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(tree=True)
|
||||||
|
async def get_threads(
|
||||||
|
self,
|
||||||
|
room_id: str,
|
||||||
|
limit: int = 5,
|
||||||
|
from_token: Optional[ThreadsNextBatch] = None,
|
||||||
|
) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
|
||||||
|
"""Get a list of thread IDs, ordered by topological ordering of their
|
||||||
|
latest reply.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: The room the event belongs to.
|
||||||
|
limit: Only fetch the most recent `limit` threads.
|
||||||
|
from_token: Fetch rows from a previous next_batch, or from the start if None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of:
|
||||||
|
A list of thread root event IDs.
|
||||||
|
|
||||||
|
The next_batch, if one exists.
|
||||||
|
"""
|
||||||
|
# Generate the pagination clause, if necessary.
|
||||||
|
#
|
||||||
|
# Find any threads where the latest reply is equal / before the last
|
||||||
|
# thread's topo ordering and earlier in stream ordering.
|
||||||
|
pagination_clause = ""
|
||||||
|
pagination_args: tuple = ()
|
||||||
|
if from_token:
|
||||||
|
pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
|
||||||
|
pagination_args = (
|
||||||
|
from_token.topological_ordering,
|
||||||
|
from_token.stream_ordering,
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
|
SELECT thread_id, topological_ordering, stream_ordering
|
||||||
|
FROM threads
|
||||||
|
WHERE
|
||||||
|
room_id = ?
|
||||||
|
{pagination_clause}
|
||||||
|
ORDER BY topological_ordering DESC, stream_ordering DESC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_threads_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
|
||||||
|
txn.execute(sql, (room_id, *pagination_args, limit + 1))
|
||||||
|
|
||||||
|
rows = cast(List[Tuple[str, int, int]], txn.fetchall())
|
||||||
|
thread_ids = [r[0] for r in rows]
|
||||||
|
|
||||||
|
# If there are more events, generate the next pagination key from the
|
||||||
|
# last thread which will be returned.
|
||||||
|
next_token = None
|
||||||
|
if len(thread_ids) > limit:
|
||||||
|
last_topo_id = rows[-2][1]
|
||||||
|
last_stream_id = rows[-2][2]
|
||||||
|
next_token = ThreadsNextBatch(last_topo_id, last_stream_id)
|
||||||
|
|
||||||
|
return thread_ids[:limit], next_token
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_thread_id(self, event_id: str) -> str:
|
async def get_thread_id(self, event_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
30
synapse/storage/schema/main/delta/73/09threads_table.sql
Normal file
30
synapse/storage/schema/main/delta/73/09threads_table.sql
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE threads (
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
-- The event ID of the root event in the thread.
|
||||||
|
thread_id TEXT NOT NULL,
|
||||||
|
-- The latest event ID and corresponding topo / stream ordering.
|
||||||
|
latest_event_id TEXT NOT NULL,
|
||||||
|
topological_ordering BIGINT NOT NULL,
|
||||||
|
stream_ordering BIGINT NOT NULL,
|
||||||
|
CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering);
|
||||||
|
|
||||||
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
|
(7309, 'threads_backfill', '{}');
|
|
@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
|
||||||
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
|
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
|
||||||
related_event_id,
|
related_event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadsTestCase(BaseRelationsTestCase):
|
||||||
|
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
|
||||||
|
def test_threads(self) -> None:
|
||||||
|
"""Create threads and ensure the ordering is due to their latest event."""
|
||||||
|
# Create 2 threads.
|
||||||
|
thread_1 = self.parent_id
|
||||||
|
res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
|
||||||
|
thread_2 = res["event_id"]
|
||||||
|
|
||||||
|
self._send_relation(RelationTypes.THREAD, "m.room.test")
|
||||||
|
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
|
||||||
|
|
||||||
|
# Request the threads in the room.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
|
||||||
|
access_token=self.user_token,
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||||
|
self.assertEqual(thread_roots, [thread_2, thread_1])
|
||||||
|
|
||||||
|
# Update the first thread, the ordering should swap.
|
||||||
|
self._send_relation(RelationTypes.THREAD, "m.room.test")
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
|
||||||
|
access_token=self.user_token,
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||||
|
self.assertEqual(thread_roots, [thread_1, thread_2])
|
||||||
|
|
||||||
|
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
|
||||||
|
def test_pagination(self) -> None:
|
||||||
|
"""Create threads and paginate through them."""
|
||||||
|
# Create 2 threads.
|
||||||
|
thread_1 = self.parent_id
|
||||||
|
res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
|
||||||
|
thread_2 = res["event_id"]
|
||||||
|
|
||||||
|
self._send_relation(RelationTypes.THREAD, "m.room.test")
|
||||||
|
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
|
||||||
|
|
||||||
|
# Request the threads in the room.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
|
||||||
|
access_token=self.user_token,
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||||
|
self.assertEqual(thread_roots, [thread_2])
|
||||||
|
|
||||||
|
# Make sure next_batch has something in it that looks like it could be a
|
||||||
|
# valid token.
|
||||||
|
next_batch = channel.json_body.get("next_batch")
|
||||||
|
self.assertIsInstance(next_batch, str, channel.json_body)
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
|
||||||
|
access_token=self.user_token,
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||||
|
self.assertEqual(thread_roots, [thread_1], channel.json_body)
|
||||||
|
|
||||||
|
self.assertNotIn("next_batch", channel.json_body, channel.json_body)
|
||||||
|
|
||||||
|
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
|
||||||
|
def test_include(self) -> None:
|
||||||
|
"""Filtering threads to all or participated in should work."""
|
||||||
|
# Thread 1 has the user as the root event.
|
||||||
|
thread_1 = self.parent_id
|
||||||
|
self._send_relation(
|
||||||
|
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Thread 2 has the user replying.
|
||||||
|
res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
|
||||||
|
thread_2 = res["event_id"]
|
||||||
|
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
|
||||||
|
|
||||||
|
# Thread 3 has the user not participating in.
|
||||||
|
res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
|
||||||
|
thread_3 = res["event_id"]
|
||||||
|
self._send_relation(
|
||||||
|
RelationTypes.THREAD,
|
||||||
|
"m.room.test",
|
||||||
|
access_token=self.user2_token,
|
||||||
|
parent_id=thread_3,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All threads in the room.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
|
||||||
|
access_token=self.user_token,
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||||
|
self.assertEqual(
|
||||||
|
thread_roots, [thread_3, thread_2, thread_1], channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only participated threads.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated",
|
||||||
|
access_token=self.user_token,
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||||
|
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
|
||||||
|
|
||||||
|
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
|
||||||
|
def test_ignored_user(self) -> None:
|
||||||
|
"""Events from ignored users should be ignored."""
|
||||||
|
# Thread 1 has a reply from an ignored user.
|
||||||
|
thread_1 = self.parent_id
|
||||||
|
self._send_relation(
|
||||||
|
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Thread 2 is created by an ignored user.
|
||||||
|
res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
|
||||||
|
thread_2 = res["event_id"]
|
||||||
|
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
|
||||||
|
|
||||||
|
# Ignore user2.
|
||||||
|
self.get_success(
|
||||||
|
self.store.add_account_data_for_user(
|
||||||
|
self.user_id,
|
||||||
|
AccountDataTypes.IGNORED_USER_LIST,
|
||||||
|
{"ignored_users": {self.user2_id: {}}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only thread 1 is returned.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
|
||||||
|
access_token=self.user_token,
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
|
||||||
|
self.assertEqual(thread_roots, [thread_1], channel.json_body)
|
||||||
|
|
Loading…
Reference in a new issue