forked from MirrorHub/synapse
Include whether the requesting user has participated in a thread. (#11577)
Per updates to MSC3440. This is implement as a separate method since it needs to be cached on a per-user basis, instead of a per-thread basis.
This commit is contained in:
parent
251b5567ec
commit
68acb0a29d
9 changed files with 86 additions and 19 deletions
1
changelog.d/11577.feature
Normal file
1
changelog.d/11577.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
|
|
@ -537,7 +537,7 @@ class PaginationHandler:
|
|||
state_dict = await self.store.get_events(list(state_ids.values()))
|
||||
state = state_dict.values()
|
||||
|
||||
aggregations = await self.store.get_bundled_aggregations(events)
|
||||
aggregations = await self.store.get_bundled_aggregations(events, user_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
|
|
@ -1182,12 +1182,18 @@ class RoomContextHandler:
|
|||
results["event"] = filtered[0]
|
||||
|
||||
# Fetch the aggregations.
|
||||
aggregations = await self.store.get_bundled_aggregations([results["event"]])
|
||||
aggregations.update(
|
||||
await self.store.get_bundled_aggregations(results["events_before"])
|
||||
aggregations = await self.store.get_bundled_aggregations(
|
||||
[results["event"]], user.to_string()
|
||||
)
|
||||
aggregations.update(
|
||||
await self.store.get_bundled_aggregations(results["events_after"])
|
||||
await self.store.get_bundled_aggregations(
|
||||
results["events_before"], user.to_string()
|
||||
)
|
||||
)
|
||||
aggregations.update(
|
||||
await self.store.get_bundled_aggregations(
|
||||
results["events_after"], user.to_string()
|
||||
)
|
||||
)
|
||||
results["aggregations"] = aggregations
|
||||
|
||||
|
|
|
@ -637,7 +637,9 @@ class SyncHandler:
|
|||
# as clients will have all the necessary information.
|
||||
bundled_aggregations = None
|
||||
if limited or newly_joined_room:
|
||||
bundled_aggregations = await self.store.get_bundled_aggregations(recents)
|
||||
bundled_aggregations = await self.store.get_bundled_aggregations(
|
||||
recents, sync_config.user.to_string()
|
||||
)
|
||||
|
||||
return TimelineBatch(
|
||||
events=recents,
|
||||
|
|
|
@ -118,7 +118,9 @@ class RelationPaginationServlet(RestServlet):
|
|||
)
|
||||
# The relations returned for the requested event do include their
|
||||
# bundled aggregations.
|
||||
aggregations = await self.store.get_bundled_aggregations(events)
|
||||
aggregations = await self.store.get_bundled_aggregations(
|
||||
events, requester.user.to_string()
|
||||
)
|
||||
serialized_events = self._event_serializer.serialize_events(
|
||||
events, now, bundle_aggregations=aggregations
|
||||
)
|
||||
|
|
|
@ -663,7 +663,9 @@ class RoomEventServlet(RestServlet):
|
|||
|
||||
if event:
|
||||
# Ensure there are bundled aggregations available.
|
||||
aggregations = await self._store.get_bundled_aggregations([event])
|
||||
aggregations = await self._store.get_bundled_aggregations(
|
||||
[event], requester.user.to_string()
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
event_dict = self._event_serializer.serialize_event(
|
||||
|
|
|
@ -1793,6 +1793,13 @@ class PersistEventsStore:
|
|||
txn.call_after(
|
||||
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
|
||||
)
|
||||
# It should be safe to only invalidate the cache if the user has not
|
||||
# previously participated in the thread, but that's difficult (and
|
||||
# potentially error-prone) so it is always invalidated.
|
||||
txn.call_after(
|
||||
self.store.get_thread_participated.invalidate,
|
||||
(parent_id, event.room_id, event.sender),
|
||||
)
|
||||
|
||||
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
||||
"""Handles keeping track of insertion events and edges/connections.
|
||||
|
|
|
@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
async def get_thread_summary(
|
||||
self, event_id: str, room_id: str
|
||||
) -> Tuple[int, Optional[EventBase]]:
|
||||
"""Get the number of threaded replies, the senders of those replies, and
|
||||
the latest reply (if any) for the given event.
|
||||
"""Get the number of threaded replies and the latest reply (if any) for the given event.
|
||||
|
||||
Args:
|
||||
event_id: Summarize the thread related to this event ID.
|
||||
|
@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
def _get_thread_summary_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[int, Optional[str]]:
|
||||
# Fetch the count of threaded events and the latest event ID.
|
||||
# Fetch the latest event ID in the thread.
|
||||
# TODO Should this only allow m.room.message events.
|
||||
sql = """
|
||||
SELECT event_id
|
||||
|
@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
latest_event_id = row[0]
|
||||
|
||||
# Fetch the number of threaded replies.
|
||||
sql = """
|
||||
SELECT COUNT(event_id)
|
||||
FROM event_relations
|
||||
|
@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
return count, latest_event
|
||||
|
||||
@cached()
|
||||
async def get_thread_participated(
|
||||
self, event_id: str, room_id: str, user_id: str
|
||||
) -> bool:
|
||||
"""Get whether the requesting user participated in a thread.
|
||||
|
||||
This is separate from get_thread_summary since that can be cached across
|
||||
all users while this value is specific to the requeser.
|
||||
|
||||
Args:
|
||||
event_id: The thread related to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
user_id: The user requesting the summary.
|
||||
|
||||
Returns:
|
||||
True if the requesting user participated in the thread, otherwise false.
|
||||
"""
|
||||
|
||||
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
|
||||
# Fetch whether the requester has participated or not.
|
||||
sql = """
|
||||
SELECT 1
|
||||
FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE
|
||||
relates_to_id = ?
|
||||
AND room_id = ?
|
||||
AND relation_type = ?
|
||||
AND sender = ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
|
||||
return bool(txn.fetchone())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_thread_summary", _get_thread_summary_txn
|
||||
)
|
||||
|
||||
async def events_have_relations(
|
||||
self,
|
||||
parent_ids: List[str],
|
||||
|
@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def _get_bundled_aggregation_for_event(
|
||||
self, event: EventBase
|
||||
self, event: EventBase, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Generate bundled aggregations for an event.
|
||||
|
||||
|
@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
Args:
|
||||
event: The event to calculate bundled aggregations for.
|
||||
user_id: The user requesting the bundled aggregations.
|
||||
|
||||
Returns:
|
||||
The bundled aggregations for an event, if bundled aggregations are
|
||||
|
@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
# If this event is the start of a thread, include a summary of the replies.
|
||||
if self._msc3440_enabled:
|
||||
(
|
||||
thread_count,
|
||||
latest_thread_event,
|
||||
) = await self.get_thread_summary(event_id, room_id)
|
||||
thread_count, latest_thread_event = await self.get_thread_summary(
|
||||
event_id, room_id
|
||||
)
|
||||
participated = await self.get_thread_participated(
|
||||
event_id, room_id, user_id
|
||||
)
|
||||
if latest_thread_event:
|
||||
aggregations[RelationTypes.THREAD] = {
|
||||
# Don't bundle aggregations as this could recurse forever.
|
||||
"latest_event": latest_thread_event,
|
||||
"count": thread_count,
|
||||
"current_user_participated": participated,
|
||||
}
|
||||
|
||||
# Store the bundled aggregations in the event metadata for later use.
|
||||
return aggregations
|
||||
|
||||
async def get_bundled_aggregations(
|
||||
self, events: Iterable[EventBase]
|
||||
self,
|
||||
events: Iterable[EventBase],
|
||||
user_id: str,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Generate bundled aggregations for events.
|
||||
|
||||
Args:
|
||||
events: The iterable of events to calculate bundled aggregations for.
|
||||
user_id: The user requesting the bundled aggregations.
|
||||
|
||||
Returns:
|
||||
A map of event ID to the bundled aggregation for the event. Not all
|
||||
|
@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
# TODO Parallelize.
|
||||
results = {}
|
||||
for event in events:
|
||||
event_result = await self._get_bundled_aggregation_for_event(event)
|
||||
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
||||
if event_result is not None:
|
||||
results[event.event_id] = event_result
|
||||
|
||||
|
|
|
@ -515,6 +515,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||
2,
|
||||
actual[RelationTypes.THREAD].get("count"),
|
||||
)
|
||||
self.assertTrue(
|
||||
actual[RelationTypes.THREAD].get("current_user_participated")
|
||||
)
|
||||
# The latest thread event has some fields that don't matter.
|
||||
self.assert_dict(
|
||||
{
|
||||
|
|
Loading…
Reference in a new issue