forked from MirrorHub/synapse
Add a thread relation type per MSC3440. (#11088)
Adds experimental support for MSC3440's `io.element.thread` relation type (and the aggregation for it).
This commit is contained in:
parent
2d91b6256e
commit
ba00e20234
8 changed files with 119 additions and 8 deletions
1
changelog.d/11088.feature
Normal file
1
changelog.d/11088.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
|
|
@ -176,6 +176,7 @@ class RelationTypes:
|
||||||
ANNOTATION = "m.annotation"
|
ANNOTATION = "m.annotation"
|
||||||
REPLACE = "m.replace"
|
REPLACE = "m.replace"
|
||||||
REFERENCE = "m.reference"
|
REFERENCE = "m.reference"
|
||||||
|
THREAD = "io.element.thread"
|
||||||
|
|
||||||
|
|
||||||
class LimitBlockingTypes:
|
class LimitBlockingTypes:
|
||||||
|
|
|
@ -26,6 +26,8 @@ class ExperimentalConfig(Config):
|
||||||
|
|
||||||
# Whether to enable experimental MSC1849 (aka relations) support
|
# Whether to enable experimental MSC1849 (aka relations) support
|
||||||
self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True)
|
self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True)
|
||||||
|
# MSC3440 (thread relation)
|
||||||
|
self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
|
||||||
|
|
||||||
# MSC3026 (busy presence state)
|
# MSC3026 (busy presence state)
|
||||||
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
|
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
|
||||||
|
|
|
@ -386,6 +386,7 @@ class EventClientSerializer:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self._msc1849_enabled = hs.config.experimental.msc1849_enabled
|
self._msc1849_enabled = hs.config.experimental.msc1849_enabled
|
||||||
|
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
|
||||||
|
|
||||||
async def serialize_event(
|
async def serialize_event(
|
||||||
self,
|
self,
|
||||||
|
@ -462,6 +463,22 @@ class EventClientSerializer:
|
||||||
"sender": edit.sender,
|
"sender": edit.sender,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# If this event is the start of a thread, include a summary of the replies.
|
||||||
|
if self._msc3440_enabled:
|
||||||
|
(
|
||||||
|
thread_count,
|
||||||
|
latest_thread_event,
|
||||||
|
) = await self.store.get_thread_summary(event_id)
|
||||||
|
if latest_thread_event:
|
||||||
|
r = serialized_event["unsigned"].setdefault("m.relations", {})
|
||||||
|
r[RelationTypes.THREAD] = {
|
||||||
|
# Don't bundle aggregations as this could recurse forever.
|
||||||
|
"latest_event": await self.serialize_event(
|
||||||
|
latest_thread_event, time_now, bundle_aggregations=False
|
||||||
|
),
|
||||||
|
"count": thread_count,
|
||||||
|
}
|
||||||
|
|
||||||
return serialized_event
|
return serialized_event
|
||||||
|
|
||||||
async def serialize_events(
|
async def serialize_events(
|
||||||
|
|
|
@ -128,9 +128,10 @@ class RelationSendServlet(RestServlet):
|
||||||
|
|
||||||
content["m.relates_to"] = {
|
content["m.relates_to"] = {
|
||||||
"event_id": parent_id,
|
"event_id": parent_id,
|
||||||
"key": aggregation_key,
|
|
||||||
"rel_type": relation_type,
|
"rel_type": relation_type,
|
||||||
}
|
}
|
||||||
|
if aggregation_key is not None:
|
||||||
|
content["m.relates_to"]["key"] = aggregation_key
|
||||||
|
|
||||||
event_dict = {
|
event_dict = {
|
||||||
"type": event_type,
|
"type": event_type,
|
||||||
|
|
|
@ -1710,6 +1710,7 @@ class PersistEventsStore:
|
||||||
RelationTypes.ANNOTATION,
|
RelationTypes.ANNOTATION,
|
||||||
RelationTypes.REFERENCE,
|
RelationTypes.REFERENCE,
|
||||||
RelationTypes.REPLACE,
|
RelationTypes.REPLACE,
|
||||||
|
RelationTypes.THREAD,
|
||||||
):
|
):
|
||||||
# Unknown relation type
|
# Unknown relation type
|
||||||
return
|
return
|
||||||
|
@ -1740,6 +1741,9 @@ class PersistEventsStore:
|
||||||
if rel_type == RelationTypes.REPLACE:
|
if rel_type == RelationTypes.REPLACE:
|
||||||
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
|
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
|
||||||
|
|
||||||
|
if rel_type == RelationTypes.THREAD:
|
||||||
|
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
|
||||||
|
|
||||||
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
||||||
"""Handles keeping track of insertion events and edges/connections.
|
"""Handles keeping track of insertion events and edges/connections.
|
||||||
Part of MSC2716.
|
Part of MSC2716.
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return await self.get_event(edit_id, allow_none=True)
|
return await self.get_event(edit_id, allow_none=True)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def get_thread_summary(
|
||||||
|
self, event_id: str
|
||||||
|
) -> Tuple[int, Optional[EventBase]]:
|
||||||
|
"""Get the number of threaded replies, the senders of those replies, and
|
||||||
|
the latest reply (if any) for the given event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id: The original event ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of items in the thread and the most recent response, if any.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
|
||||||
|
# Fetch the count of threaded events and the latest event ID.
|
||||||
|
# TODO Should this only allow m.room.message events.
|
||||||
|
sql = """
|
||||||
|
SELECT event_id
|
||||||
|
FROM event_relations
|
||||||
|
INNER JOIN events USING (event_id)
|
||||||
|
WHERE
|
||||||
|
relates_to_id = ?
|
||||||
|
AND relation_type = ?
|
||||||
|
ORDER BY topological_ordering DESC, stream_ordering DESC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (event_id, RelationTypes.THREAD))
|
||||||
|
row = txn.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return 0, None
|
||||||
|
|
||||||
|
latest_event_id = row[0]
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT COALESCE(COUNT(event_id), 0)
|
||||||
|
FROM event_relations
|
||||||
|
WHERE
|
||||||
|
relates_to_id = ?
|
||||||
|
AND relation_type = ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (event_id, RelationTypes.THREAD))
|
||||||
|
count = txn.fetchone()[0]
|
||||||
|
|
||||||
|
return count, latest_event_id
|
||||||
|
|
||||||
|
count, latest_event_id = await self.db_pool.runInteraction(
|
||||||
|
"get_thread_summary", _get_thread_summary_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
latest_event = None
|
||||||
|
if latest_event_id:
|
||||||
|
latest_event = await self.get_event(latest_event_id, allow_none=True)
|
||||||
|
|
||||||
|
return count, latest_event
|
||||||
|
|
||||||
async def has_user_annotated_event(
|
async def has_user_annotated_event(
|
||||||
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|
|
@ -101,10 +101,10 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_basic_paginate_relations(self):
|
def test_basic_paginate_relations(self):
|
||||||
"""Tests that calling pagination API correctly the latest relations."""
|
"""Tests that calling pagination API correctly the latest relations."""
|
||||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
|
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
|
||||||
self.assertEquals(200, channel.code, channel.json_body)
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
|
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
|
||||||
self.assertEquals(200, channel.code, channel.json_body)
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
annotation_id = channel.json_body["event_id"]
|
annotation_id = channel.json_body["event_id"]
|
||||||
|
|
||||||
|
@ -141,8 +141,10 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
expected_event_ids = []
|
expected_event_ids = []
|
||||||
for _ in range(10):
|
for idx in range(10):
|
||||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
|
channel = self._send_relation(
|
||||||
|
RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
|
||||||
|
)
|
||||||
self.assertEquals(200, channel.code, channel.json_body)
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
expected_event_ids.append(channel.json_body["event_id"])
|
expected_event_ids.append(channel.json_body["event_id"])
|
||||||
|
|
||||||
|
@ -386,8 +388,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEquals(400, channel.code, channel.json_body)
|
self.assertEquals(400, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
|
||||||
def test_aggregation_get_event(self):
|
def test_aggregation_get_event(self):
|
||||||
"""Test that annotations and references get correctly bundled when
|
"""Test that annotations, references, and threads get correctly bundled when
|
||||||
getting the parent event.
|
getting the parent event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -410,6 +413,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEquals(200, channel.code, channel.json_body)
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
reply_2 = channel.json_body["event_id"]
|
reply_2 = channel.json_body["event_id"]
|
||||||
|
|
||||||
|
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
thread_2 = channel.json_body["event_id"]
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/rooms/%s/event/%s" % (self.room, self.parent_id),
|
"/rooms/%s/event/%s" % (self.room, self.parent_id),
|
||||||
|
@ -429,6 +439,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
RelationTypes.REFERENCE: {
|
RelationTypes.REFERENCE: {
|
||||||
"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
|
"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
|
||||||
},
|
},
|
||||||
|
RelationTypes.THREAD: {
|
||||||
|
"count": 2,
|
||||||
|
"latest_event": {
|
||||||
|
"age": 100,
|
||||||
|
"content": {
|
||||||
|
"m.relates_to": {
|
||||||
|
"event_id": self.parent_id,
|
||||||
|
"rel_type": RelationTypes.THREAD,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"event_id": thread_2,
|
||||||
|
"origin_server_ts": 1600,
|
||||||
|
"room_id": self.room,
|
||||||
|
"sender": self.user_id,
|
||||||
|
"type": "m.room.test",
|
||||||
|
"unsigned": {"age": 100},
|
||||||
|
"user_id": self.user_id,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -559,7 +588,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
{
|
{
|
||||||
"m.relates_to": {
|
"m.relates_to": {
|
||||||
"event_id": self.parent_id,
|
"event_id": self.parent_id,
|
||||||
"key": None,
|
|
||||||
"rel_type": "m.reference",
|
"rel_type": "m.reference",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in a new issue