mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 18:43:30 +01:00
Some refactors around receipts stream (#16426)
This commit is contained in:
parent
a01ee24734
commit
80ec81dcc5
16 changed files with 111 additions and 80 deletions
1
changelog.d/16426.misc
Normal file
1
changelog.d/16426.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor some code to simplify and better type receipts stream adjacent code.
|
|
@ -216,7 +216,7 @@ class ApplicationServicesHandler:
|
|||
|
||||
def notify_interested_services_ephemeral(
|
||||
self,
|
||||
stream_key: str,
|
||||
stream_key: StreamKeyType,
|
||||
new_token: Union[int, RoomStreamToken],
|
||||
users: Collection[Union[str, UserID]],
|
||||
) -> None:
|
||||
|
@ -326,7 +326,7 @@ class ApplicationServicesHandler:
|
|||
async def _notify_interested_services_ephemeral(
|
||||
self,
|
||||
services: List[ApplicationService],
|
||||
stream_key: str,
|
||||
stream_key: StreamKeyType,
|
||||
new_token: int,
|
||||
users: Collection[Union[str, UserID]],
|
||||
) -> None:
|
||||
|
|
|
@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError
|
|||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.storage.push_rule import RuleNotFoundException
|
||||
from synapse.synapse_rust.push import get_base_rule_ids
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types import JsonDict, StreamKeyType, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -114,7 +114,9 @@ class PushRulesHandler:
|
|||
user_id: the user ID the change is for.
|
||||
"""
|
||||
stream_id = self._main_store.get_max_push_rules_stream_id()
|
||||
self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
|
||||
self._notifier.on_new_event(
|
||||
StreamKeyType.PUSH_RULES, stream_id, users=[user_id]
|
||||
)
|
||||
|
||||
async def push_rules_for_user(
|
||||
self, user: UserID
|
||||
|
|
|
@ -130,11 +130,10 @@ class ReceiptsHandler:
|
|||
|
||||
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
|
||||
"""Takes a list of receipts, stores them and informs the notifier."""
|
||||
min_batch_id: Optional[int] = None
|
||||
max_batch_id: Optional[int] = None
|
||||
|
||||
receipts_persisted: List[ReadReceipt] = []
|
||||
for receipt in receipts:
|
||||
res = await self.store.insert_receipt(
|
||||
stream_id = await self.store.insert_receipt(
|
||||
receipt.room_id,
|
||||
receipt.receipt_type,
|
||||
receipt.user_id,
|
||||
|
@ -143,30 +142,26 @@ class ReceiptsHandler:
|
|||
receipt.data,
|
||||
)
|
||||
|
||||
if not res:
|
||||
# res will be None if this receipt is 'old'
|
||||
if stream_id is None:
|
||||
# stream_id will be None if this receipt is 'old'
|
||||
continue
|
||||
|
||||
stream_id, max_persisted_id = res
|
||||
receipts_persisted.append(receipt)
|
||||
|
||||
if min_batch_id is None or stream_id < min_batch_id:
|
||||
min_batch_id = stream_id
|
||||
if max_batch_id is None or max_persisted_id > max_batch_id:
|
||||
max_batch_id = max_persisted_id
|
||||
|
||||
# Either both of these should be None or neither.
|
||||
if min_batch_id is None or max_batch_id is None:
|
||||
if not receipts_persisted:
|
||||
# no new receipts
|
||||
return False
|
||||
|
||||
affected_room_ids = list({r.room_id for r in receipts})
|
||||
max_batch_id = self.store.get_max_receipt_stream_id()
|
||||
|
||||
affected_room_ids = list({r.room_id for r in receipts_persisted})
|
||||
|
||||
self.notifier.on_new_event(
|
||||
StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
|
||||
)
|
||||
# Note that the min here shouldn't be relied upon to be accurate.
|
||||
await self.hs.get_pusherpool().on_new_receipts(
|
||||
min_batch_id, max_batch_id, affected_room_ids
|
||||
{r.user_id for r in receipts_persisted}
|
||||
)
|
||||
|
||||
return True
|
||||
|
|
|
@ -126,7 +126,7 @@ class _NotifierUserStream:
|
|||
|
||||
def notify(
|
||||
self,
|
||||
stream_key: str,
|
||||
stream_key: StreamKeyType,
|
||||
stream_id: Union[int, RoomStreamToken],
|
||||
time_now_ms: int,
|
||||
) -> None:
|
||||
|
@ -454,7 +454,7 @@ class Notifier:
|
|||
|
||||
def on_new_event(
|
||||
self,
|
||||
stream_key: str,
|
||||
stream_key: StreamKeyType,
|
||||
new_token: Union[int, RoomStreamToken],
|
||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||
rooms: Optional[StrCollection] = None,
|
||||
|
@ -655,30 +655,29 @@ class Notifier:
|
|||
events: List[Union[JsonDict, EventBase]] = []
|
||||
end_token = from_token
|
||||
|
||||
for name, source in self.event_sources.sources.get_sources():
|
||||
keyname = "%s_key" % name
|
||||
before_id = getattr(before_token, keyname)
|
||||
after_id = getattr(after_token, keyname)
|
||||
for keyname, source in self.event_sources.sources.get_sources():
|
||||
before_id = before_token.get_field(keyname)
|
||||
after_id = after_token.get_field(keyname)
|
||||
if before_id == after_id:
|
||||
continue
|
||||
|
||||
new_events, new_key = await source.get_new_events(
|
||||
user=user,
|
||||
from_key=getattr(from_token, keyname),
|
||||
from_key=from_token.get_field(keyname),
|
||||
limit=limit,
|
||||
is_guest=is_peeking,
|
||||
room_ids=room_ids,
|
||||
explicit_room_id=explicit_room_id,
|
||||
)
|
||||
|
||||
if name == "room":
|
||||
if keyname == StreamKeyType.ROOM:
|
||||
new_events = await filter_events_for_client(
|
||||
self._storage_controllers,
|
||||
user.to_string(),
|
||||
new_events,
|
||||
is_peeking=is_peeking,
|
||||
)
|
||||
elif name == "presence":
|
||||
elif keyname == StreamKeyType.PRESENCE:
|
||||
now = self.clock.time_msec()
|
||||
new_events[:] = [
|
||||
{
|
||||
|
|
|
@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
def on_new_receipts(self) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
|
|
|
@ -99,7 +99,7 @@ class EmailPusher(Pusher):
|
|||
pass
|
||||
self.timed_call = None
|
||||
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
def on_new_receipts(self) -> None:
|
||||
# We could wake up and cancel the timer but there tend to be quite a
|
||||
# lot of read receipts so it's probably less work to just let the
|
||||
# timer fire
|
||||
|
|
|
@ -160,7 +160,7 @@ class HttpPusher(Pusher):
|
|||
if should_check_for_notifs:
|
||||
self._start_processing()
|
||||
|
||||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
|
||||
def on_new_receipts(self) -> None:
|
||||
# Note that the min here shouldn't be relied upon to be accurate.
|
||||
|
||||
# We could check the receipts are actually m.read receipts here,
|
||||
|
|
|
@ -292,20 +292,12 @@ class PusherPool:
|
|||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_notifications")
|
||||
|
||||
async def on_new_receipts(
|
||||
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
|
||||
) -> None:
|
||||
async def on_new_receipts(self, users_affected: StrCollection) -> None:
|
||||
if not self.pushers:
|
||||
# nothing to do here.
|
||||
return
|
||||
|
||||
try:
|
||||
# Need to subtract 1 from the minimum because the lower bound here
|
||||
# is not inclusive
|
||||
users_affected = await self.store.get_users_sent_receipts_between(
|
||||
min_stream_id - 1, max_stream_id
|
||||
)
|
||||
|
||||
for u in users_affected:
|
||||
# Don't push if the user account has expired
|
||||
expired = await self._account_validity_handler.is_user_expired(u)
|
||||
|
@ -314,7 +306,7 @@ class PusherPool:
|
|||
|
||||
if u in self.pushers:
|
||||
for p in self.pushers[u].values():
|
||||
p.on_new_receipts(min_stream_id, max_stream_id)
|
||||
p.on_new_receipts()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_receipts")
|
||||
|
|
|
@ -129,9 +129,7 @@ class ReplicationDataHandler:
|
|||
self.notifier.on_new_event(
|
||||
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
|
||||
)
|
||||
await self._pusher_pool.on_new_receipts(
|
||||
token, token, {row.room_id for row in rows}
|
||||
)
|
||||
await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
|
||||
elif stream_name == ToDeviceStream.NAME:
|
||||
entities = [row.entity for row in rows if row.entity.startswith("@")]
|
||||
if entities:
|
||||
|
|
|
@ -208,7 +208,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
|||
"message": "Set room key",
|
||||
"room_id": room_id,
|
||||
"session_id": session_id,
|
||||
StreamKeyType.ROOM: room_key,
|
||||
StreamKeyType.ROOM.value: room_key,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -742,7 +742,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
event_ids: List[str],
|
||||
thread_id: Optional[str],
|
||||
data: dict,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
) -> Optional[int]:
|
||||
"""Insert a receipt, either from local client or remote server.
|
||||
|
||||
Automatically does conversion between linearized and graph
|
||||
|
@ -804,9 +804,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
data,
|
||||
)
|
||||
|
||||
max_persisted_id = self._receipts_id_gen.get_current_token()
|
||||
|
||||
return stream_id, max_persisted_id
|
||||
return stream_id
|
||||
|
||||
async def _insert_graph_receipt(
|
||||
self,
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Iterator, Tuple
|
||||
from typing import TYPE_CHECKING, Sequence, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
|
|||
from synapse.handlers.typing import TypingNotificationEventSource
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.streams import EventSource
|
||||
from synapse.types import StreamToken
|
||||
from synapse.types import StreamKeyType, StreamToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -37,9 +37,14 @@ class _EventSourcesInner:
|
|||
receipt: ReceiptEventSource
|
||||
account_data: AccountDataEventSource
|
||||
|
||||
def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
|
||||
for attribute in attr.fields(_EventSourcesInner):
|
||||
yield attribute.name, getattr(self, attribute.name)
|
||||
def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]:
|
||||
return [
|
||||
(StreamKeyType.ROOM, self.room),
|
||||
(StreamKeyType.PRESENCE, self.presence),
|
||||
(StreamKeyType.TYPING, self.typing),
|
||||
(StreamKeyType.RECEIPT, self.receipt),
|
||||
(StreamKeyType.ACCOUNT_DATA, self.account_data),
|
||||
]
|
||||
|
||||
|
||||
class EventSources:
|
||||
|
|
|
@ -22,8 +22,8 @@ from typing import (
|
|||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Final,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Match,
|
||||
MutableMapping,
|
||||
|
@ -34,6 +34,7 @@ from typing import (
|
|||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -649,20 +650,20 @@ class RoomStreamToken:
|
|||
return "s%d" % (self.stream,)
|
||||
|
||||
|
||||
class StreamKeyType:
|
||||
class StreamKeyType(Enum):
|
||||
"""Known stream types.
|
||||
|
||||
A stream is a list of entities ordered by an incrementing "stream token".
|
||||
"""
|
||||
|
||||
ROOM: Final = "room_key"
|
||||
PRESENCE: Final = "presence_key"
|
||||
TYPING: Final = "typing_key"
|
||||
RECEIPT: Final = "receipt_key"
|
||||
ACCOUNT_DATA: Final = "account_data_key"
|
||||
PUSH_RULES: Final = "push_rules_key"
|
||||
TO_DEVICE: Final = "to_device_key"
|
||||
DEVICE_LIST: Final = "device_list_key"
|
||||
ROOM = "room_key"
|
||||
PRESENCE = "presence_key"
|
||||
TYPING = "typing_key"
|
||||
RECEIPT = "receipt_key"
|
||||
ACCOUNT_DATA = "account_data_key"
|
||||
PUSH_RULES = "push_rules_key"
|
||||
TO_DEVICE = "to_device_key"
|
||||
DEVICE_LIST = "device_list_key"
|
||||
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
|
||||
|
||||
|
||||
|
@ -784,7 +785,7 @@ class StreamToken:
|
|||
def room_stream_id(self) -> int:
|
||||
return self.room_key.stream
|
||||
|
||||
def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
|
||||
def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
|
||||
"""Advance the given key in the token to a new value if and only if the
|
||||
new value is after the old value.
|
||||
|
||||
|
@ -797,16 +798,44 @@ class StreamToken:
|
|||
return new_token
|
||||
|
||||
new_token = self.copy_and_replace(key, new_value)
|
||||
new_id = int(getattr(new_token, key))
|
||||
old_id = int(getattr(self, key))
|
||||
new_id = new_token.get_field(key)
|
||||
old_id = self.get_field(key)
|
||||
|
||||
if old_id < new_id:
|
||||
return new_token
|
||||
else:
|
||||
return self
|
||||
|
||||
def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
|
||||
return attr.evolve(self, **{key: new_value})
|
||||
def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
|
||||
return attr.evolve(self, **{key.value: new_value})
|
||||
|
||||
@overload
|
||||
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_field(
|
||||
self,
|
||||
key: Literal[
|
||||
StreamKeyType.ACCOUNT_DATA,
|
||||
StreamKeyType.DEVICE_LIST,
|
||||
StreamKeyType.PRESENCE,
|
||||
StreamKeyType.PUSH_RULES,
|
||||
StreamKeyType.RECEIPT,
|
||||
StreamKeyType.TO_DEVICE,
|
||||
StreamKeyType.TYPING,
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
|
||||
],
|
||||
) -> int:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
|
||||
...
|
||||
|
||||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
|
||||
"""Returns the stream ID for the given key."""
|
||||
return getattr(self, key.value)
|
||||
|
||||
|
||||
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -31,7 +31,7 @@ from synapse.appservice import (
|
|||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.rest.client import login, receipts, register, room, sendtodevice
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, RoomStreamToken
|
||||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.handler.notify_interested_services_ephemeral(
|
||||
"receipt_key", 580, ["@fakerecipient:example.com"]
|
||||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
|
||||
)
|
||||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
||||
interested_service, ephemeral=[event]
|
||||
|
@ -332,7 +332,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.handler.notify_interested_services_ephemeral(
|
||||
"receipt_key", 580, ["@fakerecipient:example.com"]
|
||||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
|
||||
)
|
||||
# This method will be called, but with an empty list of events
|
||||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
||||
|
@ -634,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
|||
self.get_success(
|
||||
self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
|
||||
services=[interested_appservice],
|
||||
stream_key="receipt_key",
|
||||
stream_key=StreamKeyType.RECEIPT,
|
||||
new_token=stream_token,
|
||||
users=[self.exclusive_as_user],
|
||||
)
|
||||
|
|
|
@ -28,7 +28,7 @@ from synapse.federation.transport.server import TransportLayerServer
|
|||
from synapse.handlers.typing import TypingWriterHandler
|
||||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
@ -203,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 1)
|
||||
events = self.get_success(
|
||||
|
@ -273,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 1)
|
||||
events = self.get_success(
|
||||
|
@ -349,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
self.mock_federation_client.put_json.assert_called_once_with(
|
||||
"farm",
|
||||
|
@ -399,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 1)
|
||||
|
@ -425,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.reactor.pump([16])
|
||||
|
||||
self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 2)
|
||||
events = self.get_success(
|
||||
|
@ -459,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])]
|
||||
)
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEqual(self.event_source.get_current_key(), 3)
|
||||
|
|
Loading…
Reference in a new issue