0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-17 18:18:58 +02:00

Recursively fetch the thread for receipts & notifications. (#13824)

Consider an event to be part of a thread if you can follow a
chain of relations up to a thread root.

Part of MSC3773 & MSC3771.
This commit is contained in:
Patrick Cloke 2022-10-04 11:36:16 -04:00 committed by GitHub
parent 3e74ad20db
commit 2b6d41ebd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 2 deletions

View file

@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

View file

@ -286,8 +286,13 @@ class BulkPushRuleEvaluator:
relation.parent_id,
itertools.chain(*(r.rules() for r in rules_by_user.values())),
)
# Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
evaluator = PushRuleEvaluator(
_flatten_dict(event),

View file

@ -16,7 +16,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReceiptTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
self._main_store = hs.get_datastores().main
self._known_receipt_types = {
ReceiptTypes.READ,
@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet):
thread_id = body.get("thread_id")
if not thread_id or not isinstance(thread_id, str):
raise SynapseError(
400, "thread_id field must be a non-empty string"
400,
"thread_id field must be a non-empty string",
Codes.INVALID_PARAM,
)
if receipt_type == ReceiptTypes.FULLY_READ:
raise SynapseError(
400,
f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
Codes.INVALID_PARAM,
)
# Ensure the event ID roughly correlates to the thread ID.
if thread_id != await self._main_store.get_thread_id(event_id):
raise SynapseError(
400,
f"event_id {event_id} is not related to thread {thread_id}",
Codes.INVALID_PARAM,
)
await self.presence_handler.bump_presence_active_time(requester.user)

View file

@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations
)
@cached()
async def get_thread_id(self, event_id: str) -> Optional[str]:
"""
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
Args:
event_id: The event ID to fetch the thread ID for.
Returns:
The event ID of the root event in the thread, if this event is part
of a thread. None, otherwise.
"""
# Since event relations form a tree, we should only ever find 0 or 1
# results from the below query.
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
"""
def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id,))
# TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
return None
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
class RelationsStore(RelationsWorkerStore):
pass

View file

@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_rotate()
_assert_counts(0, 0, 0, 0)
def test_recursive_thread(self) -> None:
"""
Events related to events in a thread should still be considered part of
that thread.
"""
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
# And another users to send events.
other_id = self.register_user("other", "pass")
other_token = self.login("other", "pass")
# Create a room and put both users in it.
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)
# Update the user's push rules to care about reaction events.
self.get_success(
self.store.add_push_rule(
user_id,
"related_events",
priority_class=5,
conditions=[
{"kind": "event_match", "key": "type", "pattern": "m.reaction"}
],
actions=["notify"],
)
)
def _create_event(type: str, content: JsonDict) -> str:
result = self.helper.send_event(
room_id, type=type, content=content, tok=other_token
)
return result["event_id"]
def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
self.store._get_unread_counts_by_receipt_txn,
room_id,
user_id,
)
)
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
self.assertEqual(
counts.threads,
{
thread_id: NotifCounts(
notify_count=thread_notif_count,
unread_count=0,
highlight_count=0,
),
},
)
else:
self.assertEqual(counts.threads, {})
# Create a root event.
thread_id = _create_event(
"m.room.message", {"msgtype": "m.text", "body": "msg"}
)
_assert_counts(1, 0)
# Reply, creating a thread.
reply_id = _create_event(
"m.room.message",
{
"msgtype": "m.text",
"body": "msg",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": thread_id,
},
},
)
_assert_counts(1, 1)
# Create an event related to a thread event, this should still appear in
# the thread.
_create_event(
type="m.reaction",
content={
"m.relates_to": {
"rel_type": "m.annotation",
"event_id": reply_id,
"key": "A",
}
},
)
_assert_counts(1, 2)
def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
self.get_success(