0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-09 14:18:54 +02:00

Add a module API to allow modules to edit push rule actions (#12406)

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
This commit is contained in:
Brendan Abolivier 2022-04-27 15:55:33 +02:00 committed by GitHub
parent d743b25c8f
commit 5ef673de4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 319 additions and 104 deletions

View file

@ -0,0 +1 @@
Add a module API to allow modules to change actions for existing push rules of local users.

View file

@ -0,0 +1,138 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Union
import attr
from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.baserules import BASE_RULE_IDS
from synapse.storage.push_rule import RuleNotFoundException
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RuleSpec:
scope: str
template: str
rule_id: str
attr: Optional[str]
class PushRulesHandler:
"""A class to handle changes in push rules for users."""
def __init__(self, hs: "HomeServer"):
self._notifier = hs.get_notifier()
self._main_store = hs.get_datastores().main
async def set_rule_attr(
self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
) -> None:
"""Set an attribute (enabled or actions) on an existing push rule.
Notifies listeners (e.g. sync handler) of the change.
Args:
user_id: the user for which to modify the push rule.
spec: the spec of the push rule to modify.
val: the value to change the attribute to.
Raises:
RuleNotFoundException if the rule being modified doesn't exist.
SynapseError(400) if the value is malformed.
UnrecognizedRequestError if the attribute to change is unknown.
InvalidRuleException if we're trying to change the actions on a rule but
the provided actions aren't compliant with the spec.
"""
if spec.attr not in ("enabled", "actions"):
# for the sake of potential future expansion, shouldn't report
# 404 in the case of an unknown request so check it corresponds to
# a known attribute first.
raise UnrecognizedRequestError()
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,))
if spec.attr == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
# Legacy fallback
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
await self._main_store.set_push_rule_enabled(
user_id, namespaced_rule_id, val, is_default_rule
)
elif spec.attr == "actions":
if not isinstance(val, dict):
raise SynapseError(400, "Value must be a dict")
actions = val.get("actions")
if not isinstance(actions, list):
raise SynapseError(400, "Value for 'actions' must be dict")
check_actions(actions)
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise RuleNotFoundException(
"Unknown rule %r" % (namespaced_rule_id,)
)
await self._main_store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
raise UnrecognizedRequestError()
self.notify_user(user_id)
def notify_user(self, user_id: str) -> None:
"""Notify listeners about a push rule change.
Args:
user_id: the user ID the change is for.
"""
stream_id = self._main_store.get_max_push_rules_stream_id()
self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
def check_actions(actions: List[Union[str, JsonDict]]) -> None:
"""Check if the given actions are spec compliant.
Args:
actions: the actions to check.
Raises:
InvalidRuleException if the rules aren't compliant with the spec.
"""
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
for a in actions:
if a in ["notify", "dont_notify", "coalesce"]:
pass
elif isinstance(a, dict) and "set_tweak" in a:
pass
else:
raise InvalidRuleException("Unrecognised action %s" % a)
class InvalidRuleException(Exception):
pass

View file

@ -82,6 +82,7 @@ from synapse.handlers.auth import (
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeHtmlResource,
@ -195,6 +196,7 @@ class ModuleApi:
self._clock: Clock = hs.get_clock()
self._registration_handler = hs.get_registration_handler()
self._send_email_handler = hs.get_send_email_handler()
self._push_rules_handler = hs.get_push_rules_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
try:
@ -1352,6 +1354,68 @@ class ModuleApi:
"""
await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
def check_push_rule_actions(
self, actions: List[Union[str, Dict[str, str]]]
) -> None:
"""Checks if the given push rule actions are valid according to the Matrix
specification.
See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid
actions.
Added in Synapse v1.58.0.
Args:
actions: the actions to check.
Raises:
synapse.module_api.errors.InvalidRuleException if the actions are invalid.
"""
check_actions(actions)
async def set_push_rule_action(
self,
user_id: str,
scope: str,
kind: str,
rule_id: str,
actions: List[Union[str, Dict[str, str]]],
) -> None:
"""Changes the actions of an existing push rule for the given user.
See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more
information about push rules and their syntax.
Can only be called on the main process.
Added in Synapse v1.58.0.
Args:
user_id: the user for which to change the push rule's actions.
scope: the push rule's scope, currently only "global" is allowed.
kind: the push rule's kind.
rule_id: the push rule's identifier.
actions: the actions to run when the rule's conditions match.
Raises:
RuntimeError if this method is called on a worker or `scope` is invalid.
synapse.module_api.errors.RuleNotFoundException if the rule being modified
can't be found.
synapse.module_api.errors.InvalidRuleException if the actions are invalid.
"""
if self.worker_app is not None:
raise RuntimeError("module tried to change push rule actions on a worker")
if scope != "global":
raise RuntimeError(
"invalid scope %s, only 'global' is currently allowed" % scope
)
spec = RuleSpec(scope, kind, rule_id, "actions")
await self._push_rules_handler.set_rule_attr(
user_id, spec, {"actions": actions}
)
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room

View file

@ -20,10 +20,14 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
from synapse.handlers.push_rules import InvalidRuleException
from synapse.storage.push_rule import RuleNotFoundException
__all__ = [
"InvalidClientCredentialsError",
"RedirectException",
"SynapseError",
"ConfigError",
"InvalidRuleException",
"RuleNotFoundException",
]

View file

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
import attr
from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
from synapse.api.errors import (
NotFoundError,
@ -22,6 +20,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@ -29,7 +28,6 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
@ -40,14 +38,6 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RuleSpec:
scope: str
template: str
rule_id: str
attr: Optional[str]
class PushRuleRestServlet(RestServlet):
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet):
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
self._push_rules_handler = hs.get_push_rules_handler()
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet):
user_id = requester.user.to_string()
if spec.attr:
await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
try:
await self._push_rules_handler.set_rule_attr(user_id, spec, content)
except InvalidRuleException as e:
raise SynapseError(400, "Invalid actions: %s" % e)
except RuleNotFoundException:
raise NotFoundError("Unknown rule")
return 200, {}
if spec.rule_id.startswith("."):
@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet):
before = parse_string(request, "before")
if before:
before = _namespaced_rule_id(spec, before)
before = f"global/{spec.template}/{before}"
after = parse_string(request, "after")
if after:
after = _namespaced_rule_id(spec, after)
after = f"global/{spec.template}/{after}"
try:
await self.store.add_push_rule(
user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec),
rule_id=f"global/{spec.template}/{spec.rule_id}",
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after,
)
self.notify_user(user_id)
self._push_rules_handler.notify_user(user_id)
except InconsistentRuleException as e:
raise SynapseError(400, str(e))
except RuleNotFoundException as e:
@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
try:
await self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
self._push_rules_handler.notify_user(user_id)
return 200, {}
except StoreError as e:
if e.code == 404:
@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet):
else:
raise UnrecognizedRequestError()
def notify_user(self, user_id: str) -> None:
stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(
self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
) -> None:
if spec.attr not in ("enabled", "actions"):
# for the sake of potential future expansion, shouldn't report
# 404 in the case of an unknown request so check it corresponds to
# a known attribute first.
raise UnrecognizedRequestError()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec.attr == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
# Legacy fallback
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
await self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val, is_default_rule
)
elif spec.attr == "actions":
if not isinstance(val, dict):
raise SynapseError(400, "Value must be a dict")
actions = val.get("actions")
if not isinstance(actions, list):
raise SynapseError(400, "Value for 'actions' must be dict")
_check_actions(actions)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
raise UnrecognizedRequestError()
def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
@ -291,24 +238,11 @@ def _rule_tuple_from_request_object(
raise InvalidRuleException("No actions found")
actions = req_obj["actions"]
_check_actions(actions)
check_actions(actions)
return conditions, actions
def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
for a in actions:
if a in ["notify", "dont_notify", "coalesce"]:
pass
elif isinstance(a, dict) and "set_tweak" in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
if path == []:
raise UnrecognizedRequestError(
@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int:
return pc
def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
return _namespaced_rule_id(spec, spec.rule_id)
def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
return "global/%s/%s" % (spec.template, rule_id)
class InvalidRuleException(Exception):
pass
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushRuleRestServlet(hs).register(http_server)

View file

@ -91,6 +91,7 @@ from synapse.handlers.presence import (
WorkerPresenceHandler,
)
from synapse.handlers.profile import ProfileHandler
from synapse.handlers.push_rules import PushRulesHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
@ -810,6 +811,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_handler(self) -> AccountHandler:
return AccountHandler(self)
@cache_in_self
def get_push_rules_handler(self) -> PushRulesHandler:
return PushRulesHandler(self)
@cache_in_self
def get_outbound_redis_connection(self) -> "ConnectionHandler":
"""

View file

@ -16,7 +16,7 @@ import abc
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import NotFoundError, StoreError
from synapse.api.errors import StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore):
are always stored in the database `push_rules` table).
Raises:
NotFoundError if the rule does not exist.
RuleNotFoundException if the rule does not exist.
"""
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
txn.execute(sql, (user_id, rule_id))
if txn.fetchone() is None:
# needed to set NOT_FOUND code.
raise NotFoundError("Push rule does not exist.")
raise RuleNotFoundException("Push rule does not exist.")
self.db_pool.simple_upsert_txn(
txn,
@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore):
"""
Sets the `actions` state of a push rule.
Will throw NotFoundError if the rule does not exist; the Code for this
is NOT_FOUND.
Args:
user_id: the user ID of the user who wishes to enable/disable the rule
e.g. '@tina:example.org'
@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore):
is_default_rule: True if and only if this is a server-default rule.
This skips the check for existence (as only user-created rules
are always stored in the database `push_rules` table).
Raises:
RuleNotFoundException if the rule does not exist.
"""
actions_json = json_encoder.encode(actions)
@ -744,7 +743,7 @@ class PushRuleStore(PushRulesWorkerStore):
except StoreError as serr:
if serr.code == 404:
# this sets the NOT_FOUND error Code
raise NotFoundError("Push rule does not exist")
raise RuleNotFoundException("Push rule does not exist")
else:
raise

View file

@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
from synapse.rest import admin
from synapse.rest.client import login, presence, profile, room
from synapse.rest.client import login, notifications, presence, profile, room
from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room.register_servlets,
presence.register_servlets,
profile.register_servlets,
notifications.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
@ -553,6 +555,86 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(state[("org.matrix.test", "")].state_key, "")
self.assertEqual(state[("org.matrix.test", "")].content, {})
def test_set_push_rules_action(self) -> None:
"""Test that a module can change the actions of an existing push rule for a user."""
# Create a room with 2 users in it. Push rules must not match if the user is the
# event's sender, so we need one user to send messages and one user to receive
# notifications.
user_id = self.register_user("user", "password")
tok = self.login("user", "password")
room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok)
user_id2 = self.register_user("user2", "password")
tok2 = self.login("user2", "password")
self.helper.join(room_id, user_id2, tok=tok2)
# Register a 3rd user and join them to the room, so that we don't accidentally
# trigger 1:1 push rules.
user_id3 = self.register_user("user3", "password")
tok3 = self.login("user3", "password")
self.helper.join(room_id, user_id3, tok=tok3)
# Send a message as the second user and check that it notifies.
res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2)
event_id = res["event_id"]
channel = self.make_request(
"GET",
"/notifications",
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
self.assertEqual(
channel.json_body["notifications"][0]["event"]["event_id"],
event_id,
channel.json_body,
)
# Change the .m.rule.message actions to not notify on new messages.
self.get_success(
defer.ensureDeferred(
self.module_api.set_push_rule_action(
user_id=user_id,
scope="global",
kind="underride",
rule_id=".m.rule.message",
actions=["dont_notify"],
)
)
)
# Send another message as the second user and check that the number of
# notifications didn't change.
self.helper.send(room_id=room_id, body="here's another message", tok=tok2)
channel = self.make_request(
"GET",
"/notifications?from=",
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
def test_check_push_rules_actions(self) -> None:
"""Test that modules can check whether a list of push rules actions are spec
compliant.
"""
with self.assertRaises(InvalidRuleException):
self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException):
self.module_api.check_push_rule_actions({"foo": "bar"})
self.module_api.check_push_rule_actions(["notify"])
self.module_api.check_push_rule_actions(
[{"set_tweak": "sound", "value": "default"}]
)
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""