Ensure forked threads are not allowed.
This commit is contained in:
parent
1fff047b38
commit
a7503470d7
|
@ -1048,6 +1048,13 @@ class EventCreationHandler:
|
|||
if already_exists:
|
||||
raise SynapseError(400, "Can't send same reaction twice")
|
||||
|
||||
# If this relation is a thread, then ensure thread head is not part of
|
||||
# a thread already.
|
||||
elif relation_type == RelationTypes.THREAD:
|
||||
already_thread = await self.store.get_event_thread(relates_to)
|
||||
if already_thread:
|
||||
raise SynapseError(400, "Can't fork threads")
|
||||
|
||||
@measure_func("handle_new_client_event")
|
||||
async def handle_new_client_event(
|
||||
self,
|
||||
|
|
|
@ -372,6 +372,37 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
||||
)
|
||||
|
||||
async def get_event_thread(self, event_id: str) -> Optional[str]:
|
||||
"""Return an event's thread.
|
||||
|
||||
Args:
|
||||
event_id: The event being used as the start of a new thread.
|
||||
|
||||
Returns:
|
||||
The thread ID of the event.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT relates_to_id FROM event_relations
|
||||
WHERE
|
||||
event_id = ?
|
||||
AND relation_type = ?
|
||||
LIMIT 1;
|
||||
"""
|
||||
|
||||
def _get_thread_id(txn) -> Optional[str]:
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event_id,
|
||||
RelationTypes.THREAD,
|
||||
),
|
||||
)
|
||||
|
||||
return txn.fetchone()
|
||||
|
||||
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
|
||||
|
||||
|
||||
class RelationsStore(RelationsWorkerStore):
|
||||
pass
|
||||
|
|
|
@ -119,6 +119,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
|
||||
self.assertEquals(400, channel.code, channel.json_body)
|
||||
|
||||
def test_deny_forked_thread(self):
|
||||
"""It is invalid to start a thread off a thread."""
|
||||
channel = self._send_relation(
|
||||
RelationTypes.THREAD,
|
||||
"m.room.message",
|
||||
content={"msgtype": "m.text", "body": "foo"},
|
||||
parent_id=self.parent_id,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
parent_id = channel.json_body["event_id"]
|
||||
|
||||
channel = self._send_relation(
|
||||
RelationTypes.THREAD,
|
||||
"m.room.message",
|
||||
content={"msgtype": "m.text", "body": "foo"},
|
||||
parent_id=parent_id,
|
||||
)
|
||||
self.assertEquals(400, channel.code, channel.json_body)
|
||||
|
||||
def test_basic_paginate_relations(self):
|
||||
"""Tests that calling pagination API correctly the latest relations."""
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
|
||||
|
|
Loading…
Reference in a new issue