0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-09-28 04:28:58 +02:00

Support MSC4140: Delayed events (Futures) (#17326)

This commit is contained in:
Andrew Ferrazzutti 2024-09-23 08:33:48 -04:00 committed by GitHub
parent 75e2c17d2a
commit 5173741c71
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1772 additions and 12 deletions

View file

@ -0,0 +1,2 @@
Add initial implementation of delayed events as proposed by [MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140).

View file

@ -111,6 +111,9 @@ server_notices:
system_mxid_avatar_url: ""
room_name: "Server Alert"
# Enable delayed events (msc4140)
max_event_delay_duration: 24h
# Disable sync cache so that initial `/sync` requests are up-to-date.
caches:

View file

@ -761,6 +761,19 @@ email:
password_reset: "[%(server_name)s] Password reset"
email_validation: "[%(server_name)s] Validate your email"
```
---
### `max_event_delay_duration`
The maximum allowed duration by which sent events can be delayed, as per
[MSC4140](https://github.com/matrix-org/matrix-spec-proposals/pull/4140).
Must be a positive value if set.
Defaults to no duration (`null`), which disallows sending delayed events.
Example configuration:
```yaml
max_event_delay_duration: 24h
```
## Homeserver blocking
Useful options for Synapse admins.

View file

@ -290,6 +290,7 @@ information.
Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
^/_matrix/client/unstable/org.matrix.msc4140/delayed_events
Pagination requests can also be handled, but all requests for a given
room must be routed to the same instance. Additionally, care must be taken to

View file

@ -223,6 +223,7 @@ test_packages=(
./tests/msc3930
./tests/msc3902
./tests/msc3967
./tests/msc4140
)
# Enable dirty runs, so tests will reuse the same container where possible.

View file

@ -65,6 +65,7 @@ from synapse.storage.databases.main.appservice import (
)
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.delayed_events import DelayedEventsStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.directory import DirectoryWorkerStore
@ -161,6 +162,7 @@ class GenericWorkerStore(
TaskSchedulerWorkerStore,
ExperimentalFeaturesStore,
SlidingSyncStore,
DelayedEventsStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.

View file

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -780,6 +780,17 @@ class ServerConfig(Config):
else:
self.delete_stale_devices_after = None
# The maximum allowed delay duration for delayed events (MSC4140).
max_event_delay_duration = config.get("max_event_delay_duration")
if max_event_delay_duration is not None:
self.max_event_delay_ms: Optional[int] = self.parse_duration(
max_event_delay_duration
)
if self.max_event_delay_ms <= 0:
raise ConfigError("max_event_delay_duration must be a positive value")
else:
self.max_event_delay_ms = None
def has_tls_listener(self) -> bool:
return any(listener.is_tls() for listener in self.listeners)

View file

@ -0,0 +1,484 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple
from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes
from synapse.api.errors import ShadowBanError
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
from synapse.logging.opentracing import set_tag
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.delayed_events import (
ReplicationAddedDelayedEventRestServlet,
)
from synapse.storage.databases.main.delayed_events import (
DelayedEventDetails,
DelayID,
EventType,
StateKey,
Timestamp,
UserLocalpart,
)
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
JsonDict,
Requester,
RoomID,
UserID,
create_requester,
)
from synapse.util.events import generate_fake_event_id
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class DelayedEventsHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._config = hs.config
self._clock = hs.get_clock()
self._request_ratelimiter = hs.get_request_ratelimiter()
self._event_creation_handler = hs.get_event_creation_handler()
self._room_member_handler = hs.get_room_member_handler()
self._next_delayed_event_call: Optional[IDelayedCall] = None
# The current position in the current_state_delta stream
self._event_pos: Optional[int] = None
# Guard to ensure we only process event deltas one at a time
self._event_processing = False
if hs.config.worker.worker_app is None:
self._repl_client = None
async def _schedule_db_events() -> None:
# We kick this off to pick up outstanding work from before the last restart.
# Block until we're up to date.
await self._unsafe_process_new_event()
hs.get_notifier().add_replication_callback(self.notify_new_event)
# Kick off again (without blocking) to catch any missed notifications
# that may have fired before the callback was added.
self._clock.call_later(0, self.notify_new_event)
# Delayed events that are already marked as processed on startup might not have been
# sent properly on the last run of the server, so unmark them to send them again.
# Caveat: this will double-send delayed events that successfully persisted, but failed
# to be removed from the DB table of delayed events.
# TODO: To avoid double-sending, scan the timeline to find which of these events were
# already sent. To do so, must store delay_ids in sent events to retrieve them later.
await self._store.unprocess_delayed_events()
events, next_send_ts = await self._store.process_timeout_delayed_events(
self._get_current_ts()
)
if next_send_ts:
self._schedule_next_at(next_send_ts)
# Can send the events in background after having awaited on marking them as processed
run_as_background_process(
"_send_events",
self._send_events,
events,
)
self._initialized_from_db = run_as_background_process(
"_schedule_db_events", _schedule_db_events
)
else:
self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs)
@property
def _is_master(self) -> bool:
return self._repl_client is None
def notify_new_event(self) -> None:
"""
Called when there may be more state event deltas to process,
which should cancel pending delayed events for the same state.
"""
if self._event_processing:
return
self._event_processing = True
async def process() -> None:
try:
await self._unsafe_process_new_event()
finally:
self._event_processing = False
run_as_background_process("delayed_events.notify_new_event", process)
async def _unsafe_process_new_event(self) -> None:
# If self._event_pos is None then means we haven't fetched it from the DB yet
if self._event_pos is None:
self._event_pos = await self._store.get_delayed_events_stream_pos()
room_max_stream_ordering = self._store.get_room_max_stream_ordering()
if self._event_pos > room_max_stream_ordering:
# apparently, we've processed more events than exist in the database!
# this can happen if events are removed with history purge or similar.
logger.warning(
"Event stream ordering appears to have gone backwards (%i -> %i): "
"rewinding delayed events processor",
self._event_pos,
room_max_stream_ordering,
)
self._event_pos = room_max_stream_ordering
# Loop round handling deltas until we're up to date
while True:
with Measure(self._clock, "delayed_events_delta"):
room_max_stream_ordering = self._store.get_room_max_stream_ordering()
if self._event_pos == room_max_stream_ordering:
return
logger.debug(
"Processing delayed events %s->%s",
self._event_pos,
room_max_stream_ordering,
)
(
max_pos,
deltas,
) = await self._storage_controllers.state.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
logger.debug(
"Handling %d state deltas for delayed events processing",
len(deltas),
)
await self._handle_state_deltas(deltas)
self._event_pos = max_pos
# Expose current event processing position to prometheus
event_processing_positions.labels("delayed_events").set(max_pos)
await self._store.update_delayed_events_stream_pos(max_pos)
async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None:
"""
Process current state deltas to cancel pending delayed events
that target the same state.
"""
for delta in deltas:
logger.debug(
"Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
)
next_send_ts = await self._store.cancel_delayed_state_events(
room_id=delta.room_id,
event_type=delta.event_type,
state_key=delta.state_key,
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
async def add(
self,
requester: Requester,
*,
room_id: str,
event_type: str,
state_key: Optional[str],
origin_server_ts: Optional[int],
content: JsonDict,
delay: int,
) -> str:
"""
Creates a new delayed event and schedules its delivery.
Args:
requester: The requester of the delayed event, who will be its owner.
room_id: The ID of the room where the event should be sent to.
event_type: The type of event to be sent.
state_key: The state key of the event to be sent, or None if it is not a state event.
origin_server_ts: The custom timestamp to send the event with.
If None, the timestamp will be the actual time when the event is sent.
content: The content of the event to be sent.
delay: How long (in milliseconds) to wait before automatically sending the event.
Returns: The ID of the added delayed event.
Raises:
SynapseError: if the delayed event fails validation checks.
"""
await self._request_ratelimiter.ratelimit(requester)
self._event_creation_handler.validator.validate_builder(
self._event_creation_handler.event_builder_factory.for_room_version(
await self._store.get_room_version(room_id),
{
"type": event_type,
"content": content,
"room_id": room_id,
"sender": str(requester.user),
**({"state_key": state_key} if state_key is not None else {}),
},
)
)
creation_ts = self._get_current_ts()
delay_id, next_send_ts = await self._store.add_delayed_event(
user_localpart=requester.user.localpart,
device_id=requester.device_id,
creation_ts=creation_ts,
room_id=room_id,
event_type=event_type,
state_key=state_key,
origin_server_ts=origin_server_ts,
content=content,
delay=delay,
)
if self._repl_client is not None:
# NOTE: If this throws, the delayed event will remain in the DB and
# will be picked up once the main worker gets another delayed event.
await self._repl_client(
instance_name=MAIN_PROCESS_INSTANCE_NAME,
next_send_ts=next_send_ts,
)
elif self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
return delay_id
def on_added(self, next_send_ts: int) -> None:
next_send_ts = Timestamp(next_send_ts)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def cancel(self, requester: Requester, delay_id: str) -> None:
"""
Cancels the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._request_ratelimiter.ratelimit(requester)
await self._initialized_from_db
next_send_ts = await self._store.cancel_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
async def restart(self, requester: Requester, delay_id: str) -> None:
"""
Restarts the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._request_ratelimiter.ratelimit(requester)
await self._initialized_from_db
next_send_ts = await self._store.restart_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
current_ts=self._get_current_ts(),
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def send(self, requester: Requester, delay_id: str) -> None:
"""
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._request_ratelimiter.ratelimit(requester)
await self._initialized_from_db
event, next_send_ts = await self._store.process_target_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
await self._send_event(
DelayedEventDetails(
delay_id=DelayID(delay_id),
user_localpart=UserLocalpart(requester.user.localpart),
room_id=event.room_id,
type=event.type,
state_key=event.state_key,
origin_server_ts=event.origin_server_ts,
content=event.content,
device_id=event.device_id,
)
)
async def _send_on_timeout(self) -> None:
self._next_delayed_event_call = None
events, next_send_ts = await self._store.process_timeout_delayed_events(
self._get_current_ts()
)
if next_send_ts:
self._schedule_next_at(next_send_ts)
await self._send_events(events)
async def _send_events(self, events: List[DelayedEventDetails]) -> None:
sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set()
for event in events:
if event.state_key is not None:
state_info = (event.room_id, event.type, event.state_key)
if state_info in sent_state:
continue
else:
state_info = None
try:
# TODO: send in background if message event or non-conflicting state event
await self._send_event(event)
if state_info is not None:
sent_state.add(state_info)
except Exception:
logger.exception("Failed to send delayed event")
for room_id, event_type, state_key in sent_state:
await self._store.delete_processed_delayed_state_events(
room_id=str(room_id),
event_type=event_type,
state_key=state_key,
)
def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None:
if next_send_ts is not None:
self._schedule_next_at(next_send_ts)
elif self._next_delayed_event_call is not None:
self._next_delayed_event_call.cancel()
self._next_delayed_event_call = None
def _schedule_next_at(self, next_send_ts: Timestamp) -> None:
delay = next_send_ts - self._get_current_ts()
delay_sec = delay / 1000 if delay > 0 else 0
if self._next_delayed_event_call is None:
self._next_delayed_event_call = self._clock.call_later(
delay_sec,
run_as_background_process,
"_send_on_timeout",
self._send_on_timeout,
)
else:
self._next_delayed_event_call.reset(delay_sec)
async def get_all_for_user(self, requester: Requester) -> List[JsonDict]:
"""Return all pending delayed events requested by the given user."""
await self._request_ratelimiter.ratelimit(requester)
return await self._store.get_all_delayed_events_for_user(
requester.user.localpart
)
async def _send_event(
self,
event: DelayedEventDetails,
txn_id: Optional[str] = None,
) -> None:
user_id = UserID(event.user_localpart, self._config.server.server_name)
user_id_str = user_id.to_string()
# Create a new requester from what data is currently available
requester = create_requester(
user_id,
is_guest=await self._store.is_guest(user_id_str),
device_id=event.device_id,
)
try:
if event.state_key is not None and event.type == EventTypes.Member:
membership = event.content.get("membership")
assert membership is not None
event_id, _ = await self._room_member_handler.update_membership(
requester,
target=UserID.from_string(event.state_key),
room_id=event.room_id.to_string(),
action=membership,
content=event.content,
origin_server_ts=event.origin_server_ts,
)
else:
event_dict: JsonDict = {
"type": event.type,
"content": event.content,
"room_id": event.room_id.to_string(),
"sender": user_id_str,
}
if event.origin_server_ts is not None:
event_dict["origin_server_ts"] = event.origin_server_ts
if event.state_key is not None:
event_dict["state_key"] = event.state_key
(
sent_event,
_,
) = await self._event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
)
event_id = sent_event.event_id
except ShadowBanError:
event_id = generate_fake_event_id()
finally:
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
try:
await self._store.delete_processed_delayed_event(
event.delay_id, event.user_localpart
)
except Exception:
logger.exception("Failed to delete processed delayed event")
set_tag("event_id", event_id)
def _get_current_ts(self) -> Timestamp:
return Timestamp(self._clock.time_msec())
def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool:
# The DB alone knows if the next send time changed after adding/modifying
# a delayed event, but if we were to ever miss updating our delayed call's
# firing time, we may miss other updates. So, keep track of changes to the
# the next send time here instead of in the DB.
cached_next_send_ts = (
int(self._next_delayed_event_call.getTime() * 1000)
if self._next_delayed_event_call is not None
else None
)
return next_send_ts != cached_next_send_ts

View file

@ -23,6 +23,7 @@ from typing import TYPE_CHECKING
from synapse.http.server import JsonResource
from synapse.replication.http import (
account_data,
delayed_events,
devices,
federation,
login,
@ -64,3 +65,4 @@ class ReplicationRestResource(JsonResource):
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
delayed_events.register_servlets(hs, self)

View file

@ -0,0 +1,48 @@
import logging
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, JsonMapping
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReplicationAddedDelayedEventRestServlet(ReplicationEndpoint):
"""Handle a delayed event being added by another worker.
Request format:
POST /_synapse/replication/delayed_event_added/
{}
"""
NAME = "added_delayed_event"
PATH_ARGS = ()
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.handler = hs.get_delayed_events_handler()
@staticmethod
async def _serialize_payload(next_send_ts: int) -> JsonDict: # type: ignore[override]
return {"next_send_ts": next_send_ts}
async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict
) -> Tuple[int, Dict[str, Optional[JsonMapping]]]:
self.handler.on_added(int(content["next_send_ts"]))
return 200, {}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationAddedDelayedEventRestServlet(hs).register(http_server)

View file

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -31,6 +31,7 @@ from synapse.rest.client import (
auth,
auth_issuer,
capabilities,
delayed_events,
devices,
directory,
events,
@ -81,6 +82,7 @@ CLIENT_SERVLET_FUNCTIONS: Tuple[RegisterServletsFunc, ...] = (
room.register_deprecated_servlets,
events.register_servlets,
room.register_servlets,
delayed_events.register_servlets,
login.register_servlets,
profile.register_servlets,
presence.register_servlets,

View file

@ -0,0 +1,97 @@
# This module contains REST servlets to do with delayed events: /delayed_events/<paths>
import logging
from enum import Enum
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class _UpdateDelayedEventAction(Enum):
CANCEL = "cancel"
RESTART = "restart"
SEND = "send"
class UpdateDelayedEventServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
try:
action = str(body["action"])
except KeyError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"'action' is missing",
Codes.MISSING_PARAM,
)
try:
enum_action = _UpdateDelayedEventAction(action)
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"'action' is not one of "
+ ", ".join(f"'{m.value}'" for m in _UpdateDelayedEventAction),
Codes.INVALID_PARAM,
)
if enum_action == _UpdateDelayedEventAction.CANCEL:
await self.delayed_events_handler.cancel(requester, delay_id)
elif enum_action == _UpdateDelayedEventAction.RESTART:
await self.delayed_events_handler.restart(requester, delay_id)
elif enum_action == _UpdateDelayedEventAction.SEND:
await self.delayed_events_handler.send(requester, delay_id)
return 200, {}
class DelayedEventsServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
# TODO: Support Pagination stream API ("from" query parameter)
delayed_events = await self.delayed_events_handler.get_all_for_user(requester)
ret = {"delayed_events": delayed_events}
return 200, ret
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
UpdateDelayedEventServlet(hs).register(http_server)
DelayedEventsServlet(hs).register(http_server)

View file

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -195,7 +195,9 @@ class RoomStateEventRestServlet(RestServlet):
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
self.delayed_events_handler = hs.get_delayed_events_handler()
self.auth = hs.get_auth()
self._max_event_delay_ms = hs.config.server.max_event_delay_ms
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/state/$eventtype
@ -291,6 +293,22 @@ class RoomStateEventRestServlet(RestServlet):
if requester.app_service:
origin_server_ts = parse_integer(request, "ts")
delay = _parse_request_delay(request, self._max_event_delay_ms)
if delay is not None:
delay_id = await self.delayed_events_handler.add(
requester,
room_id=room_id,
event_type=event_type,
state_key=state_key,
origin_server_ts=origin_server_ts,
content=content,
delay=delay,
)
set_tag("delay_id", delay_id)
ret = {"delay_id": delay_id}
return 200, ret
try:
if event_type == EventTypes.Member:
membership = content.get("membership", None)
@ -341,7 +359,9 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.delayed_events_handler = hs.get_delayed_events_handler()
self.auth = hs.get_auth()
self._max_event_delay_ms = hs.config.server.max_event_delay_ms
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
@ -358,6 +378,26 @@ class RoomSendEventRestServlet(TransactionRestServlet):
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
origin_server_ts = None
if requester.app_service:
origin_server_ts = parse_integer(request, "ts")
delay = _parse_request_delay(request, self._max_event_delay_ms)
if delay is not None:
delay_id = await self.delayed_events_handler.add(
requester,
room_id=room_id,
event_type=event_type,
state_key=None,
origin_server_ts=origin_server_ts,
content=content,
delay=delay,
)
set_tag("delay_id", delay_id)
ret = {"delay_id": delay_id}
return 200, ret
event_dict: JsonDict = {
"type": event_type,
"content": content,
@ -365,8 +405,6 @@ class RoomSendEventRestServlet(TransactionRestServlet):
"sender": requester.user.to_string(),
}
if requester.app_service:
origin_server_ts = parse_integer(request, "ts")
if origin_server_ts is not None:
event_dict["origin_server_ts"] = origin_server_ts
@ -411,6 +449,49 @@ class RoomSendEventRestServlet(TransactionRestServlet):
)
def _parse_request_delay(
request: SynapseRequest,
max_delay: Optional[int],
) -> Optional[int]:
"""Parses from the request string the delay parameter for
delayed event requests, and checks it for correctness.
Args:
request: the twisted HTTP request.
max_delay: the maximum allowed value of the delay parameter,
or None if no delay parameter is allowed.
Returns:
The value of the requested delay, or None if it was absent.
Raises:
SynapseError: if the delay parameter is present and forbidden,
or if it exceeds the maximum allowed value.
"""
delay = parse_integer(request, "org.matrix.msc4140.delay")
if delay is None:
return None
if max_delay is None:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Delayed events are not supported on this server",
Codes.UNKNOWN,
{
"org.matrix.msc4140.errcode": "M_MAX_DELAY_UNSUPPORTED",
},
)
if delay > max_delay:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"The requested delay exceeds the allowed maximum.",
Codes.UNKNOWN,
{
"org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED",
"org.matrix.msc4140.max_delay": max_delay,
},
)
return delay
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
CATEGORY = "Event sending requests"

View file

@ -4,7 +4,7 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2017 Vector Creations Ltd
# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -171,6 +171,8 @@ class VersionsRestServlet(RestServlet):
is not None
)
),
# MSC4140: Delayed events
"org.matrix.msc4140": True,
# MSC4151: Report room API (Client-Server API)
"org.matrix.msc4151": self.config.experimental.msc4151_enabled,
# Simplified sliding sync

View file

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -68,6 +68,7 @@ from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, PasswordAuthProvider
from synapse.handlers.cas import CasHandler
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.delayed_events import DelayedEventsHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.directory import DirectoryHandler
@ -251,6 +252,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"account_validity",
"auth",
"deactivate_account",
"delayed_events",
"message",
"pagination",
"profile",
@ -964,3 +966,7 @@ class HomeServer(metaclass=abc.ABCMeta):
register_threadpool("media", media_threadpool)
return media_threadpool
@cache_in_self
def get_delayed_events_handler(self) -> DelayedEventsHandler:
return DelayedEventsHandler(self)

View file

@ -3,7 +3,7 @@
#
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -44,6 +44,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt
from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore
from .client_ips import ClientIpWorkerStore
from .delayed_events import DelayedEventsStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
from .directory import DirectoryStore
@ -158,6 +159,7 @@ class DataStore(
SessionStore,
TaskSchedulerWorkerStore,
SlidingSyncStore,
DelayedEventsStore,
):
def __init__(
self,

View file

@ -0,0 +1,523 @@
import logging
from typing import List, NewType, Optional, Tuple
import attr
from synapse.api.errors import NotFoundError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction, StoreError
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomID
from synapse.util import json_encoder, stringutils as stringutils
logger = logging.getLogger(__name__)
DelayID = NewType("DelayID", str)
UserLocalpart = NewType("UserLocalpart", str)
DeviceID = NewType("DeviceID", str)
EventType = NewType("EventType", str)
StateKey = NewType("StateKey", str)
Delay = NewType("Delay", int)
Timestamp = NewType("Timestamp", int)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventDetails:
room_id: RoomID
type: EventType
state_key: Optional[StateKey]
origin_server_ts: Optional[Timestamp]
content: JsonDict
device_id: Optional[DeviceID]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DelayedEventDetails(EventDetails):
delay_id: DelayID
user_localpart: UserLocalpart
class DelayedEventsStore(SQLBaseStore):
async def get_delayed_events_stream_pos(self) -> int:
"""
Gets the stream position of the background process to watch for state events
that target the same piece of state as any pending delayed events.
"""
return await self.db_pool.simple_select_one_onecol(
table="delayed_events_stream_pos",
keyvalues={},
retcol="stream_id",
desc="get_delayed_events_stream_pos",
)
async def update_delayed_events_stream_pos(self, stream_id: Optional[int]) -> None:
"""
Updates the stream position of the background process to watch for state events
that target the same piece of state as any pending delayed events.
Must only be used by the worker running the background process.
"""
await self.db_pool.simple_update_one(
table="delayed_events_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
desc="update_delayed_events_stream_pos",
)
async def add_delayed_event(
self,
*,
user_localpart: str,
device_id: Optional[str],
creation_ts: Timestamp,
room_id: str,
event_type: str,
state_key: Optional[str],
origin_server_ts: Optional[int],
content: JsonDict,
delay: int,
) -> Tuple[DelayID, Timestamp]:
"""
Inserts a new delayed event in the DB.
Returns: The generated ID assigned to the added delayed event,
and the send time of the next delayed event to be sent,
which is either the event just added or one added earlier.
"""
delay_id = _generate_delay_id()
send_ts = Timestamp(creation_ts + delay)
def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp:
self.db_pool.simple_insert_txn(
txn,
table="delayed_events",
values={
"delay_id": delay_id,
"user_localpart": user_localpart,
"device_id": device_id,
"delay": delay,
"send_ts": send_ts,
"room_id": room_id,
"event_type": event_type,
"state_key": state_key,
"origin_server_ts": origin_server_ts,
"content": json_encoder.encode(content),
},
)
next_send_ts = self._get_next_delayed_event_send_ts_txn(txn)
assert next_send_ts is not None
return next_send_ts
next_send_ts = await self.db_pool.runInteraction(
"add_delayed_event", add_delayed_event_txn
)
return delay_id, next_send_ts
async def restart_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
current_ts: Timestamp,
) -> Timestamp:
"""
Restarts the send time of the matching delayed event,
as long as it hasn't already been marked for processing.
Args:
delay_id: The ID of the delayed event to restart.
user_localpart: The localpart of the delayed event's owner.
current_ts: The current time, which will be used to calculate the new send time.
Returns: The send time of the next delayed event to be sent,
which is either the event just restarted, or another one
with an earlier send time than the restarted one's new send time.
Raises:
NotFoundError: if there is no matching delayed event.
"""
def restart_delayed_event_txn(
txn: LoggingTransaction,
) -> Timestamp:
txn.execute(
"""
UPDATE delayed_events
SET send_ts = ? + delay
WHERE delay_id = ? AND user_localpart = ?
AND NOT is_processed
""",
(
current_ts,
delay_id,
user_localpart,
),
)
if txn.rowcount == 0:
raise NotFoundError("Delayed event not found")
next_send_ts = self._get_next_delayed_event_send_ts_txn(txn)
assert next_send_ts is not None
return next_send_ts
return await self.db_pool.runInteraction(
"restart_delayed_event", restart_delayed_event_txn
)
async def get_all_delayed_events_for_user(
self,
user_localpart: str,
) -> List[JsonDict]:
"""Returns all pending delayed events owned by the given user."""
# TODO: Support Pagination stream API ("next_batch" field)
rows = await self.db_pool.execute(
"get_all_delayed_events_for_user",
"""
SELECT
delay_id,
room_id,
event_type,
state_key,
delay,
send_ts,
content
FROM delayed_events
WHERE user_localpart = ? AND NOT is_processed
ORDER BY send_ts
""",
user_localpart,
)
return [
{
"delay_id": DelayID(row[0]),
"room_id": str(RoomID.from_string(row[1])),
"type": EventType(row[2]),
**({"state_key": StateKey(row[3])} if row[3] is not None else {}),
"delay": Delay(row[4]),
"running_since": Timestamp(row[5] - row[4]),
"content": db_to_json(row[6]),
}
for row in rows
]
async def process_timeout_delayed_events(
self, current_ts: Timestamp
) -> Tuple[
List[DelayedEventDetails],
Optional[Timestamp],
]:
"""
Marks for processing all delayed events that should have been sent prior to the provided time
that haven't already been marked as such.
Returns: The details of all newly-processed delayed events,
and the send time of the next delayed event to be sent, if any.
"""
def process_timeout_delayed_events_txn(
txn: LoggingTransaction,
) -> Tuple[
List[DelayedEventDetails],
Optional[Timestamp],
]:
sql_cols = ", ".join(
(
"delay_id",
"user_localpart",
"room_id",
"event_type",
"state_key",
"origin_server_ts",
"send_ts",
"content",
"device_id",
)
)
sql_update = "UPDATE delayed_events SET is_processed = TRUE"
sql_where = "WHERE send_ts <= ? AND NOT is_processed"
sql_args = (current_ts,)
sql_order = "ORDER BY send_ts"
if isinstance(self.database_engine, PostgresEngine):
# Do this only in Postgres because:
# - SQLite's RETURNING emits rows in an arbitrary order
# - https://www.sqlite.org/lang_returning.html#limitations_and_caveats
# - SQLite does not support data-modifying statements in a WITH clause
# - https://www.sqlite.org/lang_with.html
# - https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-MODIFYING
txn.execute(
f"""
WITH events_to_send AS (
{sql_update} {sql_where} RETURNING *
) SELECT {sql_cols} FROM events_to_send {sql_order}
""",
sql_args,
)
rows = txn.fetchall()
else:
txn.execute(
f"SELECT {sql_cols} FROM delayed_events {sql_where} {sql_order}",
sql_args,
)
rows = txn.fetchall()
txn.execute(f"{sql_update} {sql_where}", sql_args)
assert txn.rowcount == len(rows)
events = [
DelayedEventDetails(
RoomID.from_string(row[2]),
EventType(row[3]),
StateKey(row[4]) if row[4] is not None else None,
# If no custom_origin_ts is set, use send_ts as the event's timestamp
Timestamp(row[5] if row[5] is not None else row[6]),
db_to_json(row[7]),
DeviceID(row[8]) if row[8] is not None else None,
DelayID(row[0]),
UserLocalpart(row[1]),
)
for row in rows
]
next_send_ts = self._get_next_delayed_event_send_ts_txn(txn)
return events, next_send_ts
return await self.db_pool.runInteraction(
"process_timeout_delayed_events", process_timeout_delayed_events_txn
)
async def process_target_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
) -> Tuple[
EventDetails,
Optional[Timestamp],
]:
"""
Marks for processing the matching delayed event, regardless of its timeout time,
as long as it has not already been marked as such.
Args:
delay_id: The ID of the delayed event to restart.
user_localpart: The localpart of the delayed event's owner.
Returns: The details of the matching delayed event,
and the send time of the next delayed event to be sent, if any.
Raises:
NotFoundError: if there is no matching delayed event.
"""
def process_target_delayed_event_txn(
txn: LoggingTransaction,
) -> Tuple[
EventDetails,
Optional[Timestamp],
]:
sql_cols = ", ".join(
(
"room_id",
"event_type",
"state_key",
"origin_server_ts",
"content",
"device_id",
)
)
sql_update = "UPDATE delayed_events SET is_processed = TRUE"
sql_where = "WHERE delay_id = ? AND user_localpart = ? AND NOT is_processed"
sql_args = (delay_id, user_localpart)
txn.execute(
(
f"{sql_update} {sql_where} RETURNING {sql_cols}"
if self.database_engine.supports_returning
else f"SELECT {sql_cols} FROM delayed_events {sql_where}"
),
sql_args,
)
row = txn.fetchone()
if row is None:
raise NotFoundError("Delayed event not found")
elif not self.database_engine.supports_returning:
txn.execute(f"{sql_update} {sql_where}", sql_args)
assert txn.rowcount == 1
event = EventDetails(
RoomID.from_string(row[0]),
EventType(row[1]),
StateKey(row[2]) if row[2] is not None else None,
Timestamp(row[3]) if row[3] is not None else None,
db_to_json(row[4]),
DeviceID(row[5]) if row[5] is not None else None,
)
return event, self._get_next_delayed_event_send_ts_txn(txn)
return await self.db_pool.runInteraction(
"process_target_delayed_event", process_target_delayed_event_txn
)
async def cancel_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
) -> Optional[Timestamp]:
"""
Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.
Args:
delay_id: The ID of the delayed event to restart.
user_localpart: The localpart of the delayed event's owner.
Returns: The send time of the next delayed event to be sent, if any.
Raises:
NotFoundError: if there is no matching delayed event.
"""
def cancel_delayed_event_txn(
txn: LoggingTransaction,
) -> Optional[Timestamp]:
try:
self.db_pool.simple_delete_one_txn(
txn,
table="delayed_events",
keyvalues={
"delay_id": delay_id,
"user_localpart": user_localpart,
"is_processed": False,
},
)
except StoreError:
if txn.rowcount == 0:
raise NotFoundError("Delayed event not found")
else:
raise
return self._get_next_delayed_event_send_ts_txn(txn)
return await self.db_pool.runInteraction(
"cancel_delayed_event", cancel_delayed_event_txn
)
async def cancel_delayed_state_events(
self,
*,
room_id: str,
event_type: str,
state_key: str,
) -> Optional[Timestamp]:
"""
Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed.
Returns: The send time of the next delayed event to be sent, if any.
"""
def cancel_delayed_state_events_txn(
txn: LoggingTransaction,
) -> Optional[Timestamp]:
self.db_pool.simple_delete_txn(
txn,
table="delayed_events",
keyvalues={
"room_id": room_id,
"event_type": event_type,
"state_key": state_key,
"is_processed": False,
},
)
return self._get_next_delayed_event_send_ts_txn(txn)
return await self.db_pool.runInteraction(
"cancel_delayed_state_events", cancel_delayed_state_events_txn
)
async def delete_processed_delayed_event(
self,
delay_id: DelayID,
user_localpart: UserLocalpart,
) -> None:
"""
Delete the matching delayed event, as long as it has been marked as processed.
Throws:
StoreError: if there is no matching delayed event, or if it has not yet been processed.
"""
return await self.db_pool.simple_delete_one(
table="delayed_events",
keyvalues={
"delay_id": delay_id,
"user_localpart": user_localpart,
"is_processed": True,
},
desc="delete_processed_delayed_event",
)
async def delete_processed_delayed_state_events(
self,
*,
room_id: str,
event_type: str,
state_key: str,
) -> None:
"""
Delete the matching delayed state events that have been marked as processed.
"""
await self.db_pool.simple_delete(
table="delayed_events",
keyvalues={
"room_id": room_id,
"event_type": event_type,
"state_key": state_key,
"is_processed": True,
},
desc="delete_processed_delayed_state_events",
)
async def unprocess_delayed_events(self) -> None:
"""
Unmark all delayed events for processing.
"""
await self.db_pool.simple_update(
table="delayed_events",
keyvalues={"is_processed": True},
updatevalues={"is_processed": False},
desc="unprocess_delayed_events",
)
async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]:
"""
Returns the send time of the next delayed event to be sent, if any.
"""
return await self.db_pool.runInteraction(
"get_next_delayed_event_send_ts",
self._get_next_delayed_event_send_ts_txn,
db_autocommit=True,
)
def _get_next_delayed_event_send_ts_txn(
self, txn: LoggingTransaction
) -> Optional[Timestamp]:
result = self.db_pool.simple_select_one_onecol_txn(
txn,
table="delayed_events",
keyvalues={"is_processed": False},
retcol="MIN(send_ts)",
allow_none=True,
)
return Timestamp(result) if result is not None else None
def _generate_delay_id() -> DelayID:
"""Generates an opaque string, for use as a delay ID"""
# We use the following format for delay IDs:
# syd_<random string>
# They are scoped to user localparts, so it is possible for
# the same ID to exist for multiple users.
return DelayID(f"syd_{stringutils.random_string(20)}")

View file

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 87 # remember to update the list below when updating
SCHEMA_VERSION = 88 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@ -149,6 +149,10 @@ Changes in SCHEMA_VERSION = 87
- Add tables for storing the per-connection state for sliding sync requests:
sliding_sync_connections, sliding_sync_connection_positions, sliding_sync_connection_required_state,
sliding_sync_connection_room_configs, sliding_sync_connection_streams
Changes in SCHEMA_VERSION = 88
- MSC4140: Add `delayed_events` table that keeps track of events that are to
be posted in response to a resettable timeout or an on-demand action.
"""

View file

@ -0,0 +1,30 @@
CREATE TABLE delayed_events (
delay_id TEXT NOT NULL,
user_localpart TEXT NOT NULL,
device_id TEXT,
delay BIGINT NOT NULL,
send_ts BIGINT NOT NULL,
room_id TEXT NOT NULL,
event_type TEXT NOT NULL,
state_key TEXT,
origin_server_ts BIGINT,
content bytea NOT NULL,
is_processed BOOLEAN NOT NULL DEFAULT FALSE,
PRIMARY KEY (user_localpart, delay_id)
);
CREATE INDEX delayed_events_send_ts ON delayed_events (send_ts);
CREATE INDEX delayed_events_is_processed ON delayed_events (is_processed);
CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL;
CREATE TABLE delayed_events_stream_pos (
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT NOT NULL,
CHECK (Lock='X')
);
-- Start processing events from the point this migration was run, rather
-- than the beginning of time.
INSERT INTO delayed_events_stream_pos (
stream_id
) SELECT COALESCE(MAX(stream_ordering), 0) from events;

View file

@ -0,0 +1,346 @@
"""Tests REST events for /delayed_events paths."""
from http import HTTPStatus
from typing import List
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest.client import delayed_events, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
_HS_NAME = "red"
_EVENT_TYPE = "com.example.test"
class DelayedEventsTestCase(HomeserverTestCase):
"""Tests getting and managing delayed events."""
servlets = [delayed_events.register_servlets, room.register_servlets]
user_id = f"@sid1:{_HS_NAME}"
def default_config(self) -> JsonDict:
config = super().default_config()
config["server_name"] = _HS_NAME
config["max_event_delay_duration"] = "24h"
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(
self.user_id,
extra_content={
"preset": "trusted_private_chat",
},
)
def test_delayed_events_empty_on_startup(self) -> None:
self.assertListEqual([], self._get_delayed_events())
def test_delayed_state_events_are_sent_on_timeout(self) -> None:
state_key = "to_send_on_timeout"
setter_key = "setter"
setter_expected = "on_timeout"
channel = self.make_request(
"PUT",
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
{
setter_key: setter_expected,
},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
content = self._get_delayed_event_content(events[0])
self.assertEqual(setter_expected, content.get(setter_key), content)
self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND,
)
self.reactor.advance(1)
self.assertListEqual([], self._get_delayed_events())
content = self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
)
self.assertEqual(setter_expected, content.get(setter_key), content)
def test_update_delayed_event_without_id(self) -> None:
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/",
)
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
def test_update_delayed_event_without_body(self) -> None:
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
Codes.NOT_JSON,
channel.json_body["errcode"],
)
def test_update_delayed_event_without_action(self) -> None:
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
{},
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
Codes.MISSING_PARAM,
channel.json_body["errcode"],
)
def test_update_delayed_event_with_invalid_action(self) -> None:
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
{"action": "oops"},
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
Codes.INVALID_PARAM,
channel.json_body["errcode"],
)
@parameterized.expand(["cancel", "restart", "send"])
def test_update_delayed_event_without_match(self, action: str) -> None:
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
{"action": action},
)
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
def test_cancel_delayed_state_event(self) -> None:
state_key = "to_never_send"
setter_key = "setter"
setter_expected = "none"
channel = self.make_request(
"PUT",
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
{
setter_key: setter_expected,
},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
self.reactor.advance(1)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
content = self._get_delayed_event_content(events[0])
self.assertEqual(setter_expected, content.get(setter_key), content)
self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "cancel"},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events())
self.reactor.advance(1)
content = self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND,
)
def test_send_delayed_state_event(self) -> None:
state_key = "to_send_on_request"
setter_key = "setter"
setter_expected = "on_send"
channel = self.make_request(
"PUT",
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 100000),
{
setter_key: setter_expected,
},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
self.reactor.advance(1)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
content = self._get_delayed_event_content(events[0])
self.assertEqual(setter_expected, content.get(setter_key), content)
self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "send"},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events())
content = self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
)
self.assertEqual(setter_expected, content.get(setter_key), content)
def test_restart_delayed_state_event(self) -> None:
state_key = "to_send_on_restarted_timeout"
setter_key = "setter"
setter_expected = "on_timeout"
channel = self.make_request(
"PUT",
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
{
setter_key: setter_expected,
},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
self.assertIsNotNone(delay_id)
self.reactor.advance(1)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
content = self._get_delayed_event_content(events[0])
self.assertEqual(setter_expected, content.get(setter_key), content)
self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "restart"},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.reactor.advance(1)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
content = self._get_delayed_event_content(events[0])
self.assertEqual(setter_expected, content.get(setter_key), content)
self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND,
)
self.reactor.advance(1)
self.assertListEqual([], self._get_delayed_events())
content = self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
)
self.assertEqual(setter_expected, content.get(setter_key), content)
def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None:
state_key = "to_be_cancelled"
setter_key = "setter"
channel = self.make_request(
"PUT",
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
{
setter_key: "on_timeout",
},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
setter_expected = "manual"
self.helper.send_state(
self.room_id,
_EVENT_TYPE,
{
setter_key: setter_expected,
},
None,
state_key=state_key,
)
self.assertListEqual([], self._get_delayed_events())
self.reactor.advance(1)
content = self.helper.get_state(
self.room_id,
_EVENT_TYPE,
"",
state_key=state_key,
)
self.assertEqual(setter_expected, content.get(setter_key), content)
def _get_delayed_events(self) -> List[JsonDict]:
channel = self.make_request(
"GET",
PATH_PREFIX,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
key = "delayed_events"
self.assertIn(key, channel.json_body)
events = channel.json_body[key]
self.assertIsInstance(events, list)
return events
def _get_delayed_event_content(self, event: JsonDict) -> JsonDict:
key = "content"
self.assertIn(key, event)
content = event[key]
self.assertIsInstance(content, dict)
return content
def _get_path_for_delayed_state(
room_id: str, event_type: str, state_key: str, delay_ms: int
) -> str:
return f"rooms/{room_id}/state/{event_type}/{state_key}?org.matrix.msc4140.delay={delay_ms}"

View file

@ -2291,6 +2291,106 @@ class RoomMessageFilterTestCase(RoomBase):
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
class RoomDelayedEventTestCase(RoomBase):
"""Tests delayed events."""
user_id = "@sid1:red"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
@unittest.override_config({"max_event_delay_duration": "24h"})
def test_send_delayed_invalid_event(self) -> None:
"""Test sending a delayed event with invalid content."""
channel = self.make_request(
"PUT",
(
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
% self.room_id
).encode("ascii"),
{},
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertNotIn("org.matrix.msc4140.errcode", channel.json_body)
def test_delayed_event_unsupported_by_default(self) -> None:
"""Test that sending a delayed event is unsupported with the default config."""
channel = self.make_request(
"PUT",
(
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
% self.room_id
).encode("ascii"),
{"body": "test", "msgtype": "m.text"},
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
"M_MAX_DELAY_UNSUPPORTED",
channel.json_body.get("org.matrix.msc4140.errcode"),
channel.json_body,
)
@unittest.override_config({"max_event_delay_duration": "1000"})
def test_delayed_event_exceeds_max_delay(self) -> None:
"""Test that sending a delayed event fails if its delay is longer than allowed."""
channel = self.make_request(
"PUT",
(
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
% self.room_id
).encode("ascii"),
{"body": "test", "msgtype": "m.text"},
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
"M_MAX_DELAY_EXCEEDED",
channel.json_body.get("org.matrix.msc4140.errcode"),
channel.json_body,
)
@unittest.override_config({"max_event_delay_duration": "24h"})
def test_delayed_event_with_negative_delay(self) -> None:
"""Test that sending a delayed event fails if its delay is negative."""
channel = self.make_request(
"PUT",
(
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=-2000"
% self.room_id
).encode("ascii"),
{"body": "test", "msgtype": "m.text"},
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
Codes.INVALID_PARAM, channel.json_body["errcode"], channel.json_body
)
@unittest.override_config({"max_event_delay_duration": "24h"})
def test_send_delayed_message_event(self) -> None:
"""Test sending a valid delayed message event."""
channel = self.make_request(
"PUT",
(
"rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
% self.room_id
).encode("ascii"),
{"body": "test", "msgtype": "m.text"},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
@unittest.override_config({"max_event_delay_duration": "24h"})
def test_send_delayed_state_event(self) -> None:
"""Test sending a valid delayed state event."""
channel = self.make_request(
"PUT",
(
"rooms/%s/state/m.room.topic/?org.matrix.msc4140.delay=2000"
% self.room_id
).encode("ascii"),
{"topic": "This is a topic"},
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,