Support stable identifiers for MSC2285: private read receipts. (#13273)

This adds support for the stable identifiers of MSC2285 while
continuing to support the unstable identifiers behind the configuration
flag. These will be removed in a future version.
This commit is contained in:
Šimon Brandner 2022-08-05 17:09:33 +02:00 committed by GitHub
parent e2ed1b7155
commit ab18441573
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 246 additions and 94 deletions

View file

@ -0,0 +1 @@
Add support for stable prefixes for [MSC2285 (private read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285).

View file

@ -257,7 +257,8 @@ class GuestAccess:
class ReceiptTypes: class ReceiptTypes:
READ: Final = "m.read" READ: Final = "m.read"
READ_PRIVATE: Final = "org.matrix.msc2285.read.private" READ_PRIVATE: Final = "m.read.private"
UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
FULLY_READ: Final = "m.fully_read" FULLY_READ: Final = "m.fully_read"

View file

@ -32,7 +32,7 @@ class ExperimentalConfig(Config):
# MSC2716 (importing historical messages) # MSC2716 (importing historical messages)
self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
# MSC2285 (private read receipts) # MSC2285 (unstable private read receipts)
self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
# MSC3244 (room version capabilities) # MSC3244 (room version capabilities)

View file

@ -143,8 +143,8 @@ class InitialSyncHandler:
joined_rooms, joined_rooms,
to_key=int(now_token.receipt_key), to_key=int(now_token.receipt_key),
) )
if self.hs.config.experimental.msc2285_enabled:
receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
tags_by_room = await self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)
@ -456,11 +456,8 @@ class InitialSyncHandler:
) )
if not receipts: if not receipts:
return [] return []
if self.hs.config.experimental.msc2285_enabled:
receipts = ReceiptEventSource.filter_out_private_receipts( return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
receipts, user_id
)
return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable( presence, receipts, (messages, token) = await make_deferred_yieldable(
gather_results( gather_results(

View file

@ -163,7 +163,10 @@ class ReceiptsHandler:
if not is_new: if not is_new:
return return
if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: if self.federation_sender and receipt_type not in (
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
):
await self.federation_sender.send_read_receipt(receipt) await self.federation_sender.send_read_receipt(receipt)
@ -203,24 +206,38 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
for event_id, orig_event_content in room.get("content", {}).items(): for event_id, orig_event_content in room.get("content", {}).items():
event_content = orig_event_content event_content = orig_event_content
# If there are private read receipts, additional logic is necessary. # If there are private read receipts, additional logic is necessary.
if ReceiptTypes.READ_PRIVATE in event_content: if (
ReceiptTypes.READ_PRIVATE in event_content
or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content
):
# Make a copy without private read receipts to avoid leaking # Make a copy without private read receipts to avoid leaking
# other user's private read receipts.. # other user's private read receipts..
event_content = { event_content = {
receipt_type: receipt_value receipt_type: receipt_value
for receipt_type, receipt_value in event_content.items() for receipt_type, receipt_value in event_content.items()
if receipt_type != ReceiptTypes.READ_PRIVATE if receipt_type
not in (
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
)
} }
# Copy the current user's private read receipt from the # Copy the current user's private read receipt from the
# original content, if it exists. # original content, if it exists.
user_private_read_receipt = orig_event_content[ user_private_read_receipt = orig_event_content.get(
ReceiptTypes.READ_PRIVATE ReceiptTypes.READ_PRIVATE, {}
].get(user_id, None) ).get(user_id, None)
if user_private_read_receipt: if user_private_read_receipt:
event_content[ReceiptTypes.READ_PRIVATE] = { event_content[ReceiptTypes.READ_PRIVATE] = {
user_id: user_private_read_receipt user_id: user_private_read_receipt
} }
user_unstable_private_read_receipt = orig_event_content.get(
ReceiptTypes.UNSTABLE_READ_PRIVATE, {}
).get(user_id, None)
if user_unstable_private_read_receipt:
event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = {
user_id: user_unstable_private_read_receipt
}
# Include the event if there is at least one non-private read # Include the event if there is at least one non-private read
# receipt or the current user has a private read receipt. # receipt or the current user has a private read receipt.
@ -256,10 +273,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
room_ids, from_key=from_key, to_key=to_key room_ids, from_key=from_key, to_key=to_key
) )
if self.config.experimental.msc2285_enabled: events = ReceiptEventSource.filter_out_private_receipts(
events = ReceiptEventSource.filter_out_private_receipts( events, user.to_string()
events, user.to_string() )
)
return events, to_key return events, to_key

View file

@ -416,7 +416,10 @@ class FederationSenderHandler:
if not self._is_mine_id(receipt.user_id): if not self._is_mine_id(receipt.user_id):
continue continue
# Private read receipts never get sent over federation. # Private read receipts never get sent over federation.
if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: if receipt.receipt_type in (
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
):
continue continue
receipt_info = ReadReceipt( receipt_info = ReadReceipt(
receipt.room_id, receipt.room_id,

View file

@ -58,7 +58,12 @@ class NotificationsServlet(RestServlet):
) )
receipts_by_room = await self.store.get_receipts_for_user_with_orderings( receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] user_id,
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
notif_event_ids = [pa.event_id for pa in push_actions] notif_event_ids = [pa.event_id for pa in push_actions]

View file

@ -40,9 +40,13 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ} self._known_receipt_types = {
ReceiptTypes.READ,
ReceiptTypes.FULLY_READ,
ReceiptTypes.READ_PRIVATE,
}
if hs.config.experimental.msc2285_enabled: if hs.config.experimental.msc2285_enabled:
self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE) self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str

View file

@ -44,11 +44,13 @@ class ReceiptRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._known_receipt_types = {ReceiptTypes.READ} self._known_receipt_types = {
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.FULLY_READ,
}
if hs.config.experimental.msc2285_enabled: if hs.config.experimental.msc2285_enabled:
self._known_receipt_types.update( self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
(ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
)
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str

View file

@ -94,6 +94,7 @@ class VersionsRestServlet(RestServlet):
# Supports the busy presence state described in MSC3026. # Supports the busy presence state described in MSC3026.
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving private read receipts as per MSC2285 # Supports receiving private read receipts as per MSC2285
"org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec
"org.matrix.msc2285": self.config.experimental.msc2285_enabled, "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
# Supports filtering of /publicRooms by room type as per MSC3827 # Supports filtering of /publicRooms by room type as per MSC3827
"org.matrix.msc3827.stable": True, "org.matrix.msc3827.stable": True,

View file

@ -80,7 +80,7 @@ import attr
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
@ -259,7 +259,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn, txn,
user_id, user_id,
room_id, room_id,
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), receipt_types=(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
) )
stream_ordering = None stream_ordering = None
@ -448,6 +452,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by ascending stream_ordering. The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
# find rooms that have a read receipt in them and return the next # find rooms that have a read receipt in them and return the next
# push actions # push actions
def get_after_receipt( def get_after_receipt(
@ -455,7 +460,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) -> List[Tuple[str, str, int, str, bool]]: ) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next # find rooms that have a read receipt in them and return the next
# push actions # push actions
sql = """
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
)
sql = f"""
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
ep.highlight ep.highlight
FROM ( FROM (
@ -463,10 +479,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
MAX(stream_ordering) as stream_ordering MAX(stream_ordering) as stream_ordering
FROM events FROM events
INNER JOIN receipts_linearized USING (room_id, event_id) INNER JOIN receipts_linearized USING (room_id, event_id)
WHERE receipt_type = 'm.read' AND user_id = ? WHERE {receipt_types_clause} AND user_id = ?
GROUP BY room_id GROUP BY room_id
) AS rl, ) AS rl,
event_push_actions AS ep event_push_actions AS ep
WHERE WHERE
ep.room_id = rl.room_id ep.room_id = rl.room_id
AND ep.stream_ordering > rl.stream_ordering AND ep.stream_ordering > rl.stream_ordering
@ -476,7 +492,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND ep.notif = 1 AND ep.notif = 1
ORDER BY ep.stream_ordering ASC LIMIT ? ORDER BY ep.stream_ordering ASC LIMIT ?
""" """
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
@ -490,7 +508,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt( def get_no_receipt(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]: ) -> List[Tuple[str, str, int, str, bool]]:
sql = """ receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
)
sql = f"""
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
ep.highlight ep.highlight
FROM event_push_actions AS ep FROM event_push_actions AS ep
@ -498,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
WHERE WHERE
ep.room_id NOT IN ( ep.room_id NOT IN (
SELECT room_id FROM receipts_linearized SELECT room_id FROM receipts_linearized
WHERE receipt_type = 'm.read' AND user_id = ? WHERE {receipt_types_clause} AND user_id = ?
GROUP BY room_id GROUP BY room_id
) )
AND ep.user_id = ? AND ep.user_id = ?
@ -507,7 +535,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND ep.notif = 1 AND ep.notif = 1
ORDER BY ep.stream_ordering ASC LIMIT ? ORDER BY ep.stream_ordering ASC LIMIT ?
""" """
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
@ -557,12 +587,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by descending received_ts. The list will be ordered by descending received_ts.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
# find rooms that have a read receipt in them and return the most recent # find rooms that have a read receipt in them and return the most recent
# push actions # push actions
def get_after_receipt( def get_after_receipt(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]: ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = """ receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
)
sql = f"""
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
ep.highlight, e.received_ts ep.highlight, e.received_ts
FROM ( FROM (
@ -570,7 +611,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
MAX(stream_ordering) as stream_ordering MAX(stream_ordering) as stream_ordering
FROM events FROM events
INNER JOIN receipts_linearized USING (room_id, event_id) INNER JOIN receipts_linearized USING (room_id, event_id)
WHERE receipt_type = 'm.read' AND user_id = ? WHERE {receipt_types_clause} AND user_id = ?
GROUP BY room_id GROUP BY room_id
) AS rl, ) AS rl,
event_push_actions AS ep event_push_actions AS ep
@ -584,7 +625,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND ep.notif = 1 AND ep.notif = 1
ORDER BY ep.stream_ordering DESC LIMIT ? ORDER BY ep.stream_ordering DESC LIMIT ?
""" """
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
@ -598,7 +641,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt( def get_no_receipt(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]: ) -> List[Tuple[str, str, int, str, bool, int]]:
sql = """ receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
)
sql = f"""
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
ep.highlight, e.received_ts ep.highlight, e.received_ts
FROM event_push_actions AS ep FROM event_push_actions AS ep
@ -606,7 +659,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
WHERE WHERE
ep.room_id NOT IN ( ep.room_id NOT IN (
SELECT room_id FROM receipts_linearized SELECT room_id FROM receipts_linearized
WHERE receipt_type = 'm.read' AND user_id = ? WHERE {receipt_types_clause} AND user_id = ?
GROUP BY room_id GROUP BY room_id
) )
AND ep.user_id = ? AND ep.user_id = ?
@ -615,7 +668,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND ep.notif = 1 AND ep.notif = 1
ORDER BY ep.stream_ordering DESC LIMIT ? ORDER BY ep.stream_ordering DESC LIMIT ?
""" """
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args.extend(
(user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
txn.execute(sql, args) txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())

View file

@ -15,6 +15,8 @@
from copy import deepcopy from copy import deepcopy
from typing import List from typing import List
from parameterized import parameterized
from synapse.api.constants import EduTypes, ReceiptTypes from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.types import JsonDict from synapse.types import JsonDict
@ -25,13 +27,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.event_source = hs.get_event_sources().sources.receipt self.event_source = hs.get_event_sources().sources.receipt
def test_filters_out_private_receipt(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_private_receipt(self, receipt_type: str) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$1435641916114394fHBLK:matrix.org": { "$1435641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
} }
@ -45,13 +50,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[], [],
) )
def test_filters_out_private_receipt_and_ignores_rest(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_private_receipt_and_ignores_rest(
self, receipt_type: str
) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": { "$1dgdgrd5641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -84,13 +94,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(
self, receipt_type: str
) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$14356419edgd14394fHBLK:matrix.org": { "$14356419edgd14394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -125,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_handles_empty_event(self): def test_handles_empty_event(self) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
@ -160,13 +175,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(
self, receipt_type: str
) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$14356419edgd14394fHBLK:matrix.org": { "$14356419edgd14394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -207,7 +227,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_handles_string_data(self): def test_handles_string_data(self) -> None:
""" """
Tests that an invalid shape for read-receipts is handled. Tests that an invalid shape for read-receipts is handled.
Context: https://github.com/matrix-org/synapse/issues/10603 Context: https://github.com/matrix-org/synapse/issues/10603
@ -242,13 +262,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_leaves_our_private_and_their_public(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None:
self._test_filters_private( self._test_filters_private(
[ [
{ {
"content": { "content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": { "$1dgdgrd5641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@me:server.org": { "@me:server.org": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -273,7 +296,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{ {
"content": { "content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": { "$1dgdgrd5641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@me:server.org": { "@me:server.org": {
"ts": 1436451550453, "ts": 1436451550453,
}, },
@ -296,13 +319,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_we_do_not_mutate(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_we_do_not_mutate(self, receipt_type: str) -> None:
"""Ensure the input values are not modified.""" """Ensure the input values are not modified."""
events = [ events = [
{ {
"content": { "content": {
"$1435641916114394fHBLK:matrix.org": { "$1435641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: { receipt_type: {
"@rikj:jki.re": { "@rikj:jki.re": {
"ts": 1436451550453, "ts": 1436451550453,
} }
@ -320,7 +346,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def _test_filters_private( def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict] self, events: List[JsonDict], expected_output: List[JsonDict]
): ) -> None:
"""Tests that the _filter_out_private returns the expected output""" """Tests that the _filter_out_private returns the expected output"""
filtered_events = self.event_source.filter_out_private_receipts( filtered_events = self.event_source.filter_out_private_receipts(
events, "@me:server.org" events, "@me:server.org"

View file

@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin, KnockingStrippedStateEventHelperMixin,
) )
from tests.server import TimedOutException from tests.server import TimedOutException
from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase): class FilterTestCase(unittest.HomeserverTestCase):
@ -390,6 +389,12 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["experimental_features"] = {"msc2285_enabled": True}
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s" self.url = "/sync?since=%s"
self.next_batch = "s0" self.next_batch = "s0"
@ -408,15 +413,17 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Join the second user # Join the second user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@override_config({"experimental_features": {"msc2285_enabled": True}}) @parameterized.expand(
def test_private_read_receipts(self) -> None: [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_private_read_receipts(self, receipt_type: str) -> None:
# Send a message as the first user # Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok) res = self.helper.send(self.room_id, body="hello", tok=self.tok)
# Send a private read receipt to tell the server the first user's message was read # Send a private read receipt to tell the server the first user's message was read
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok2, access_token=self.tok2,
) )
@ -425,8 +432,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's private read receipt # Test that the first user can't see the other user's private read receipt
self.assertIsNone(self._get_read_receipt()) self.assertIsNone(self._get_read_receipt())
@override_config({"experimental_features": {"msc2285_enabled": True}}) @parameterized.expand(
def test_public_receipt_can_override_private(self) -> None: [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_public_receipt_can_override_private(self, receipt_type: str) -> None:
""" """
Sending a public read receipt to the same event which has a private read Sending a public read receipt to the same event which has a private read
receipt should cause that receipt to become public. receipt should cause that receipt to become public.
@ -437,7 +446,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt # Send a private read receipt
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok2, access_token=self.tok2,
) )
@ -456,8 +465,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we did override the private read receipt # Test that we did override the private read receipt
self.assertNotEqual(self._get_read_receipt(), None) self.assertNotEqual(self._get_read_receipt(), None)
@override_config({"experimental_features": {"msc2285_enabled": True}}) @parameterized.expand(
def test_private_receipt_cannot_override_public(self) -> None: [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None:
""" """
Sending a private read receipt to the same event which has a public read Sending a private read receipt to the same event which has a public read
receipt should cause no change. receipt should cause no change.
@ -478,7 +489,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt # Send a private read receipt
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok2, access_token=self.tok2,
) )
@ -590,7 +601,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
tok=self.tok, tok=self.tok,
) )
def test_unread_counts(self) -> None: @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_unread_counts(self, receipt_type: str) -> None:
"""Tests that /sync returns the right value for the unread count (MSC2654).""" """Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count. # Check that our own messages don't increase the unread count.
@ -624,7 +638,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event. # Send a read receipt to tell the server we've read the latest event.
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
{}, {},
access_token=self.tok, access_token=self.tok,
) )
@ -700,7 +714,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(5) self._check_unread_count(5)
res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
# Make sure both m.read and org.matrix.msc2285.read.private advance # Make sure both m.read and m.read.private advance
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
@ -712,16 +726,22 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
{}, {},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0) self._check_unread_count(0)
# We test for both receipt types that influence notification counts # We test for all three receipt types that influence notification counts
@parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) @parameterized.expand(
def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: [
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
]
)
def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
# Join the new user # Join the new user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@ -739,11 +759,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0) self._check_unread_count(0)
# Make sure neither m.read nor org.matrix.msc2285.read.private make the # Make sure neither m.read nor m.read.private make the
# read receipt go up to an older event # read receipt go up to an older event
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}",
{}, {},
access_token=self.tok, access_token=self.tok,
) )

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from parameterized import parameterized
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase): class ReceiptTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver) -> None:
super().prepare(reactor, clock, homeserver) super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main self.store = homeserver.get_datastores().main
@ -83,10 +85,15 @@ class ReceiptTestCase(HomeserverTestCase):
) )
) )
def test_return_empty_with_no_data(self): def test_return_empty_with_no_data(self) -> None:
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user( self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] OUR_USER_ID,
[
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
) )
self.assertEqual(res, {}) self.assertEqual(res, {})
@ -94,7 +101,11 @@ class ReceiptTestCase(HomeserverTestCase):
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user_with_orderings( self.store.get_receipts_for_user_with_orderings(
OUR_USER_ID, OUR_USER_ID,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
) )
self.assertEqual(res, {}) self.assertEqual(res, {})
@ -103,12 +114,19 @@ class ReceiptTestCase(HomeserverTestCase):
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, OUR_USER_ID,
self.room_id1, self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.UNSTABLE_READ_PRIVATE,
],
) )
) )
self.assertEqual(res, None) self.assertEqual(res, None)
def test_get_receipts_for_user(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_get_receipts_for_user(self, receipt_type: str) -> None:
# Send some events into the first room # Send some events into the first room
event1_1_id = self.create_and_send_event( event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID) self.room_id1, UserID.from_string(OTHER_USER_ID)
@ -126,14 +144,14 @@ class ReceiptTestCase(HomeserverTestCase):
# Send private read receipt for the second event # Send private read receipt for the second event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
) )
) )
# Test we get the latest event when we want both private and public receipts # Test we get the latest event when we want both private and public receipts
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user( self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
) )
) )
self.assertEqual(res, {self.room_id1: event1_2_id}) self.assertEqual(res, {self.room_id1: event1_2_id})
@ -146,7 +164,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test we get the latest event when we want only the public receipt # Test we get the latest event when we want only the public receipt
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type])
) )
self.assertEqual(res, {self.room_id1: event1_2_id}) self.assertEqual(res, {self.room_id1: event1_2_id})
@ -169,17 +187,20 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns # Test new room is reflected in what the method returns
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
) )
) )
res = self.get_success( res = self.get_success(
self.store.get_receipts_for_user( self.store.get_receipts_for_user(
OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] OUR_USER_ID, [ReceiptTypes.READ, receipt_type]
) )
) )
self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
def test_get_last_receipt_event_id_for_user(self): @parameterized.expand(
[ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
)
def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None:
# Send some events into the first room # Send some events into the first room
event1_1_id = self.create_and_send_event( event1_1_id = self.create_and_send_event(
self.room_id1, UserID.from_string(OTHER_USER_ID) self.room_id1, UserID.from_string(OTHER_USER_ID)
@ -197,7 +218,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Send private read receipt for the second event # Send private read receipt for the second event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {}
) )
) )
@ -206,7 +227,7 @@ class ReceiptTestCase(HomeserverTestCase):
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, OUR_USER_ID,
self.room_id1, self.room_id1,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [ReceiptTypes.READ, receipt_type],
) )
) )
self.assertEqual(res, event1_2_id) self.assertEqual(res, event1_2_id)
@ -222,7 +243,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test we get the latest event when we want only the private receipt # Test we get the latest event when we want only the private receipt
res = self.get_success( res = self.get_success(
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] OUR_USER_ID, self.room_id1, [receipt_type]
) )
) )
self.assertEqual(res, event1_2_id) self.assertEqual(res, event1_2_id)
@ -248,14 +269,14 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns # Test new room is reflected in what the method returns
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {}
) )
) )
res = self.get_success( res = self.get_success(
self.store.get_last_receipt_event_id_for_user( self.store.get_last_receipt_event_id_for_user(
OUR_USER_ID, OUR_USER_ID,
self.room_id2, self.room_id2,
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], [ReceiptTypes.READ, receipt_type],
) )
) )
self.assertEqual(res, event2_1_id) self.assertEqual(res, event2_1_id)