Add an API for listing threads in a room. (#13394)

Implement the /threads endpoint from MSC3856.

This is currently unstable and behind an experimental configuration
flag.

It includes a background update to backfill data; results from
the /threads endpoint will be partial until that finishes.
Patrick Cloke 2022-10-13 08:02:11 -04:00 committed by GitHub
parent b6baa46db0
commit 3bbe532abb
10 changed files with 522 additions and 6 deletions
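
For context, a client exercises the new endpoint roughly as below. This is a sketch only: the homeserver URL, room ID, and access token are placeholders, and the server must have the experimental flag enabled (see the config change further down).

import requests  # any HTTP client works; requests is used for illustration

BASE_URL = "https://homeserver.example"  # placeholder
ROOM_ID = "!room:homeserver.example"  # placeholder
ACCESS_TOKEN = "..."  # placeholder access token

resp = requests.get(
    f"{BASE_URL}/_matrix/client/unstable/org.matrix.msc3856/rooms/{ROOM_ID}/threads",
    params={"limit": 5, "include": "all"},
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
)
resp.raise_for_status()
body = resp.json()
# body["chunk"] holds serialized thread root events, most recently replied-to
# first; body.get("next_batch"), when present, is echoed back as the `from`
# query parameter to fetch the next page.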

View file: changelog.d/13394.feature

@@ -0,0 +1 @@
+Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.

View file: synapse/_scripts/synapse_port_db.py

@@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import (
     RegistrationBackgroundUpdateStore,
     find_max_generated_user_id_localpart,
 )
+from synapse.storage.databases.main.relations import RelationsWorkerStore
 from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
 from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
 from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
@@ -206,6 +207,7 @@ class Store(
     PusherWorkerStore,
     PresenceBackgroundUpdateStore,
     ReceiptsBackgroundUpdateStore,
+    RelationsWorkerStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

View file: synapse/config/experimental.py

@@ -101,6 +101,9 @@ class ExperimentalConfig(Config):
         # MSC3848: Introduce errcodes for specific event sending failures
         self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
 
+        # MSC3856: Threads list API
+        self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
+
         # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
         self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
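
For reference, the new flag is read from the experimental_features section of homeserver.yaml, so an operator would enable it like this (matching the override_config used in the tests below):

experimental_features:
  msc3856_enabled: true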

View file: synapse/handlers/relations.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
@@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -32,6 +33,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class ThreadsListInclude(str, enum.Enum):
+    """Valid values for the 'include' flag of /threads."""
+
+    all = "all"
+    participated = "participated"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
     # The latest event in the thread.
@@ -482,3 +490,79 @@ class RelationsHandler:
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        include: ThreadsListInclude,
+        limit: int = 5,
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> JsonDict:
+        """Get the threads in a room, ordered by the topological ordering of
+        their latest reply.
+
+        Args:
+            requester: The user requesting the relations.
+            room_id: The room the event belongs to.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
+            limit: Only fetch the most recent `limit` events.
+            from_token: Fetch rows from the given token, or from the start if None.
+
+        Returns:
+            The pagination chunk.
+        """
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, requester, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_threads below.
+        # Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_batch = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        if include == ThreadsListInclude.participated:
+            # Pre-seed thread participation with whether the requester sent the event.
+            participated = {event.event_id: event.sender == user_id for event in events}
+            # For events the requester did not send, check the database for whether
+            # the requester sent a threaded reply.
+            participated.update(
+                await self._main_store.get_threads_participated(
+                    [eid for eid, p in participated.items() if not p],
+                    user_id,
+                )
+            )
+
+            # Limit the returned threads to those the user has participated in.
+            events = [event for event in events if participated[event.event_id]]
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+
+        now = self._clock.time_msec()
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_batch:
+            return_value["next_batch"] = str(next_batch)
+
+        return return_value
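
For orientation, the JsonDict built above looks roughly like the following; event IDs, content, and the token value are invented for illustration:

# Illustrative shape of the handler's return value:
example_response = {
    "chunk": [
        # Serialized thread root events, ordered by most recent reply.
        {"event_id": "$thread_root_b", "type": "m.room.message", "content": {}},
        {"event_id": "$thread_root_a", "type": "m.room.message", "content": {}},
    ],
    # Present only when another page exists: the stringified ThreadsNextBatch,
    # i.e. "<topological_ordering>_<stream_ordering>".
    "next_batch": "12_345",
}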

View file: synapse/rest/client/relations.py

@@ -13,12 +13,15 @@
 # limitations under the License.
 
 import logging
+import re
 from typing import TYPE_CHECKING, Optional, Tuple
 
+from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
+from synapse.storage.databases.main.relations import ThreadsNextBatch
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict
@@ -78,5 +81,50 @@ class RelationPaginationServlet(RestServlet):
 
         return 200, result
 
 
+class ThreadsServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+        self._relations_handler = hs.get_relations_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token_str = parse_string(request, "from")
+        include = parse_string(
+            request,
+            "include",
+            default=ThreadsListInclude.all.value,
+            allowed_values=[v.value for v in ThreadsListInclude],
+        )
+
+        # Parse the pagination token, if one was provided.
+        from_token = None
+        if from_token_str:
+            from_token = ThreadsNextBatch.from_string(from_token_str)
+
+        result = await self._relations_handler.get_threads(
+            requester=requester,
+            room_id=room_id,
+            include=ThreadsListInclude(include),
+            limit=limit,
+            from_token=from_token,
+        )
+
+        return 200, result
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationPaginationServlet(hs).register(http_server)
+    if hs.config.experimental.msc3856_enabled:
+        ThreadsServlet(hs).register(http_server)

View file: synapse/storage/databases/main/cache.py

@@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
+            self._attempt_to_invalidate_cache("get_threads", (room_id,))
 
     async def invalidate_cache_and_stream(
         self, cache_name: str, keys: Tuple[Any, ...]

View file: synapse/storage/databases/main/events.py

@@ -35,7 +35,7 @@ import attr
 from prometheus_client import Counter
 
 import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
@@ -1616,7 +1616,7 @@ class PersistEventsStore:
         )
 
         # Remove from relations table.
-        self._handle_redact_relations(txn, event.redacts)
+        self._handle_redact_relations(txn, event.room_id, event.redacts)
 
     # Update the event_forward_extremities, event_backward_extremities and
     # event_edges tables.
@@ -1866,6 +1866,34 @@ class PersistEventsStore:
             },
         )
 
+        if relation.rel_type == RelationTypes.THREAD:
+            # Upsert into the threads table, but only overwrite the value if the
+            # new event is of a later topological order OR if the topological
+            # ordering is equal, but the stream ordering is later.
+            sql = """
+            INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
+            VALUES (?, ?, ?, ?, ?)
+            ON CONFLICT (room_id, thread_id)
+            DO UPDATE SET
+                latest_event_id = excluded.latest_event_id,
+                topological_ordering = excluded.topological_ordering,
+                stream_ordering = excluded.stream_ordering
+            WHERE
+                threads.topological_ordering <= excluded.topological_ordering AND
+                threads.stream_ordering < excluded.stream_ordering
+            """
+
+            txn.execute(
+                sql,
+                (
+                    event.room_id,
+                    relation.parent_id,
+                    event.event_id,
+                    event.depth,
+                    event.internal_metadata.stream_ordering,
+                ),
+            )
+
     def _handle_insertion_event(
         self, txn: LoggingTransaction, event: EventBase
     ) -> None:
@@ -1989,13 +2017,14 @@ class PersistEventsStore:
         txn.execute(sql, (batch_id,))
 
     def _handle_redact_relations(
-        self, txn: LoggingTransaction, redacted_event_id: str
+        self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
     ) -> None:
         """Handles receiving a redaction and checking whether the redacted event
         has any relations which must be removed from the database.
 
         Args:
             txn
+            room_id: The room ID of the event that was redacted.
             redacted_event_id: The event that was redacted.
         """
@@ -2024,6 +2053,9 @@ class PersistEventsStore:
                 self.store._invalidate_cache_and_stream(
                     txn, self.store.get_thread_participated, (redacted_relates_to,)
                 )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_threads, (room_id,)
+            )
 
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
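
Note the WHERE clause on the upsert above: the stored latest event is only replaced when the new event ties or sorts later topologically and strictly advances the stream ordering. A standalone sketch of that predicate (the function name is illustrative, not part of the diff):

def should_replace_latest_event(
    current_topo: int, current_stream: int, new_topo: int, new_stream: int
) -> bool:
    # Mirrors the SQL: threads.topological_ordering <= excluded.topological_ordering
    # AND threads.stream_ordering < excluded.stream_ordering
    return current_topo <= new_topo and current_stream < new_stream

# A later reply in the same topological "generation" wins...
assert should_replace_latest_event(5, 100, 5, 101)
# ...but an event whose stream ordering does not advance never does.
assert not should_replace_latest_event(5, 100, 6, 99)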

View file: synapse/storage/databases/main/relations.py

@@ -14,6 +14,7 @@
 
 import logging
 from typing import (
+    TYPE_CHECKING,
     Collection,
     Dict,
     FrozenSet,
@@ -29,17 +30,46 @@ from typing import (
 import attr
 
 from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadsNextBatch:
+    topological_ordering: int
+    stream_ordering: int
+
+    def __str__(self) -> str:
+        return f"{self.topological_ordering}_{self.stream_ordering}"
+
+    @classmethod
+    def from_string(cls, string: str) -> "ThreadsNextBatch":
+        """
+        Creates a ThreadsNextBatch from its textual representation.
+        """
+        try:
+            keys = (int(s) for s in string.split("_"))
+            return cls(*keys)
+        except Exception:
+            raise SynapseError(400, "Invalid threads token")
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _RelatedEvent:
     """
@@ -56,6 +86,76 @@ class _RelatedEvent:
 
 
 class RelationsWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_update_handler(
+            "threads_backfill", self._backfill_threads
+        )
+
+    async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
+        """Backfill the threads table."""
+
+        def threads_backfill_txn(txn: LoggingTransaction) -> int:
+            last_thread_id = progress.get("last_thread_id", "")
+
+            # Get the latest event in each thread by topo ordering / stream ordering.
+            #
+            # Note that the MAX(event_id) is needed to abide by the rules of group by,
+            # but doesn't actually do anything since there should only be a single event
+            # ID per topo/stream ordering pair.
+            sql = f"""
+            SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                relates_to_id > ? AND
+                relation_type = '{RelationTypes.THREAD}'
+            GROUP BY room_id, relates_to_id
+            ORDER BY relates_to_id
+            LIMIT ?
+            """
+            txn.execute(sql, (last_thread_id, batch_size))
+
+            rows = txn.fetchall()
+            # No more rows to process.
+            if not rows:
+                return 0
+
+            # Insert the rows into the threads table. If a matching thread already exists,
+            # assume it is from a newer event.
+            sql = """
+            INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
+            VALUES %s
+            ON CONFLICT (room_id, thread_id)
+            DO NOTHING
+            """
+            if isinstance(txn.database_engine, PostgresEngine):
+                txn.execute_values(sql % ("?",), rows, fetch=False)
+            else:
+                txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows)
+
+            # Mark the progress.
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
+            )
+
+            return txn.rowcount
+
+        result = await self.db_pool.runInteraction(
+            "threads_backfill", threads_backfill_txn
+        )
+
+        if not result:
+            await self.db_pool.updates._end_background_update("threads_backfill")
+
+        return result
+
     @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
@@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    @cached(tree=True)
+    async def get_threads(
+        self,
+        room_id: str,
+        limit: int = 5,
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+        """Get a list of thread IDs, ordered by topological ordering of their
+        latest reply.
+
+        Args:
+            room_id: The room the event belongs to.
+            limit: Only fetch the most recent `limit` threads.
+            from_token: Fetch rows from a previous next_batch, or from the start if None.
+
+        Returns:
+            A tuple of:
+                A list of thread root event IDs.
+                The next_batch, if one exists.
+        """
+        # Generate the pagination clause, if necessary.
+        #
+        # Find any threads where the latest reply is equal / before the last
+        # thread's topo ordering and earlier in stream ordering.
+        pagination_clause = ""
+        pagination_args: tuple = ()
+        if from_token:
+            pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
+            pagination_args = (
+                from_token.topological_ordering,
+                from_token.stream_ordering,
+            )
+
+        sql = f"""
+            SELECT thread_id, topological_ordering, stream_ordering
+            FROM threads
+            WHERE
+                room_id = ?
+                {pagination_clause}
+            ORDER BY topological_ordering DESC, stream_ordering DESC
+            LIMIT ?
+        """
+
+        def _get_threads_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+            txn.execute(sql, (room_id, *pagination_args, limit + 1))
+
+            rows = cast(List[Tuple[str, int, int]], txn.fetchall())
+            thread_ids = [r[0] for r in rows]
+
+            # If there are more events, generate the next pagination key from the
+            # last thread which will be returned.
+            next_token = None
+            if len(thread_ids) > limit:
+                last_topo_id = rows[-2][1]
+                last_stream_id = rows[-2][2]
+                next_token = ThreadsNextBatch(last_topo_id, last_stream_id)
+
+            return thread_ids[:limit], next_token
+
+        return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
     @cached()
     async def get_thread_id(self, event_id: str) -> str:
         """

View file: synapse/storage/schema/main/delta/73/09threads_table.sql

@@ -0,0 +1,30 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE threads (
+    room_id TEXT NOT NULL,
+    -- The event ID of the root event in the thread.
+    thread_id TEXT NOT NULL,
+    -- The latest event ID and corresponding topo / stream ordering.
+    latest_event_id TEXT NOT NULL,
+    topological_ordering BIGINT NOT NULL,
+    stream_ordering BIGINT NOT NULL,
+    CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
+);
+
+CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering);
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7309, 'threads_backfill', '{}');

View file: tests/rest/client/test_relations.py

@@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             relations[RelationTypes.THREAD]["latest_event"]["event_id"],
             related_event_id,
         )
+
+
+class ThreadsTestCase(BaseRelationsTestCase):
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_threads(self) -> None:
+        """Create threads and ensure the ordering is due to their latest event."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2, thread_1])
+
+        # Update the first thread, the ordering should swap.
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1, thread_2])
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_pagination(self) -> None:
+        """Create threads and paginate through them."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2])
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        next_batch = channel.json_body.get("next_batch")
+        self.assertIsInstance(next_batch, str, channel.json_body)
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1], channel.json_body)
+
+        self.assertNotIn("next_batch", channel.json_body, channel.json_body)
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_include(self) -> None:
+        """Filtering threads to all or participated in should work."""
+        # Thread 1 has the user as the root event.
+        thread_1 = self.parent_id
+        self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+
+        # Thread 2 has the user replying.
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+        thread_2 = res["event_id"]
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # The user is not participating in thread 3.
+        res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
+        thread_3 = res["event_id"]
+        self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.test",
+            access_token=self.user2_token,
+            parent_id=thread_3,
+        )
+
+        # All threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(
+            thread_roots, [thread_3, thread_2, thread_1], channel.json_body
+        )
+
+        # Only participated threads.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_ignored_user(self) -> None:
+        """Events from ignored users should be ignored."""
+        # Thread 1 has a reply from an ignored user.
+        thread_1 = self.parent_id
+        self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+
+        # Thread 2 is created by an ignored user.
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+        thread_2 = res["event_id"]
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Ignore user2.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {self.user2_id: {}}},
+            )
+        )
+
+        # Only thread 1 is returned.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1], channel.json_body)