From 119e0795a58548fb38fab299e7c362fcbb388d68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Feb 2023 14:02:19 -0500 Subject: [PATCH] Implement MSC3966: Add a push rule condition to search for a value in an array. (#15045) The `exact_event_property_contains` condition can be used to search for a value inside of an array. --- changelog.d/15045.feature | 1 + rust/benches/evaluator.rs | 32 +++++++----- rust/src/push/evaluator.rs | 65 +++++++++++++++++++----- rust/src/push/mod.rs | 33 +++++++++++- stubs/synapse/synapse_rust/push.pyi | 7 +-- synapse/config/experimental.py | 5 ++ synapse/push/bulk_push_rule_evaluator.py | 21 +++++--- synapse/types/__init__.py | 1 + tests/push/test_push_rule_evaluator.py | 53 +++++++++++++++++-- 9 files changed, 176 insertions(+), 42 deletions(-) create mode 100644 changelog.d/15045.feature diff --git a/changelog.d/15045.feature b/changelog.d/15045.feature new file mode 100644 index 000000000..87766befd --- /dev/null +++ b/changelog.d/15045.feature @@ -0,0 +1 @@ +Experimental support for [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966): the `exact_event_property_contains` push rule condition. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 229553ebf..8213dfd9e 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -15,8 +15,8 @@ #![feature(test)] use std::collections::BTreeSet; use synapse::push::{ - evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, - SimpleJsonValue, + evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, JsonValue, + PushRules, SimpleJsonValue, }; use test::Bencher; @@ -27,15 +27,15 @@ fn bench_match_exact(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -54,6 +54,7 @@ fn bench_match_exact(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -76,15 +77,15 @@ fn bench_match_word(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -103,6 +104,7 @@ fn bench_match_word(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -125,15 +127,15 @@ fn bench_match_word_miss(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -152,6 +154,7 @@ fn bench_match_word_miss(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -174,15 +177,15 @@ fn bench_eval_message(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -201,6 +204,7 @@ fn bench_eval_message(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index dd6b4343e..2eaa06ad7 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -14,6 +14,7 @@ use std::collections::{BTreeMap, BTreeSet}; +use crate::push::JsonValue; use anyhow::{Context, Error}; use lazy_static::lazy_static; use log::warn; @@ -63,7 +64,7 @@ impl RoomVersionFeatures { pub struct PushRuleEvaluator { /// A mapping of "flattened" keys to simple JSON values in the event, e.g. /// includes things like "type" and "content.msgtype". - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, /// The "content.body", if any. body: String, @@ -87,7 +88,7 @@ pub struct PushRuleEvaluator { /// The related events, indexed by relation type. Flattened in the same manner as /// `flattened_keys`. - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, @@ -101,6 +102,9 @@ pub struct PushRuleEvaluator { /// If MSC3758 (exact_event_match push rule condition) is enabled. msc3758_exact_event_match: bool, + + /// If MSC3966 (exact_event_property_contains push rule condition) is enabled. + msc3966_exact_event_property_contains: bool, } #[pymethods] @@ -109,21 +113,22 @@ impl PushRuleEvaluator { #[allow(clippy::too_many_arguments)] #[new] pub fn py_new( - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, has_mentions: bool, user_mentions: BTreeSet, room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, related_event_match_enabled: bool, room_version_feature_flags: Vec, msc3931_enabled: bool, msc3758_exact_event_match: bool, + msc3966_exact_event_property_contains: bool, ) -> Result { let body = match flattened_keys.get("content.body") { - Some(SimpleJsonValue::Str(s)) => s.clone(), + Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone(), _ => String::new(), }; @@ -141,6 +146,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc3758_exact_event_match, + msc3966_exact_event_property_contains, }) } @@ -263,6 +269,9 @@ impl PushRuleEvaluator { KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } + KnownCondition::ExactEventPropertyContains(exact_event_match) => { + self.match_exact_event_property_contains(exact_event_match)? + } KnownCondition::IsUserMention => { if let Some(uid) = user_id { self.user_mentions.contains(uid) @@ -345,7 +354,7 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(SimpleJsonValue::Str(haystack)) = + let haystack = if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = self.flattened_keys.get(&*event_match.key) { haystack @@ -377,7 +386,9 @@ impl PushRuleEvaluator { let value = &exact_event_match.value; - let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) { + let haystack = if let Some(JsonValue::Value(haystack)) = + self.flattened_keys.get(&*exact_event_match.key) + { haystack } else { return Ok(false); @@ -441,11 +452,12 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) { - haystack - } else { - return Ok(false); - }; + let haystack = + if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; // For the content.body we match against "words", but for everything // else we match against the entire value. @@ -459,6 +471,29 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `exact_event_property_contains` condition. (MSC3758) + fn match_exact_event_property_contains( + &self, + exact_event_match: &ExactEventMatchCondition, + ) -> Result { + // First check if the feature is enabled. + if !self.msc3966_exact_event_property_contains { + return Ok(false); + } + + let value = &exact_event_match.value; + + let haystack = if let Some(JsonValue::Array(haystack)) = + self.flattened_keys.get(&*exact_event_match.key) + { + haystack + } else { + return Ok(false); + }; + + Ok(haystack.contains(&**value)) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -488,7 +523,7 @@ fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert( "content.body".to_string(), - SimpleJsonValue::Str("foo bar bob hello".to_string()), + JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())), ); let evaluator = PushRuleEvaluator::py_new( flattened_keys, @@ -503,6 +538,7 @@ fn push_rule_evaluator() { vec![], true, true, + true, ) .unwrap(); @@ -519,7 +555,7 @@ fn test_requires_room_version_supports_condition() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert( "content.body".to_string(), - SimpleJsonValue::Str("foo bar bob hello".to_string()), + JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())), ); let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( @@ -535,6 +571,7 @@ fn test_requires_room_version_supports_condition() { flags, true, true, + true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 79e519fe1..253b5f367 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -58,7 +58,7 @@ use anyhow::{Context, Error}; use log::warn; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyLong, PyString}; +use pyo3::types::{PyBool, PyList, PyLong, PyString}; use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; @@ -280,6 +280,35 @@ impl<'source> FromPyObject<'source> for SimpleJsonValue { } } +/// A JSON values (list, string, int, boolean, or null). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum JsonValue { + Array(Vec), + Value(SimpleJsonValue), +} + +impl<'source> FromPyObject<'source> for JsonValue { + fn extract(ob: &'source PyAny) -> PyResult { + if let Ok(l) = ::try_from(ob) { + match l.iter().map(SimpleJsonValue::extract).collect() { + Ok(a) => Ok(JsonValue::Array(a)), + Err(e) => Err(PyTypeError::new_err(format!( + "Can't convert to JsonValue::Array: {}", + e + ))), + } + } else if let Ok(v) = SimpleJsonValue::extract(ob) { + Ok(JsonValue::Value(v)) + } else { + Err(PyTypeError::new_err(format!( + "Can't convert from {} to JsonValue", + ob.get_type().name()? + ))) + } + } +} + /// A condition used in push rules to match against an event. /// /// We need this split as `serde` doesn't give us the ability to have a @@ -303,6 +332,8 @@ pub enum KnownCondition { ExactEventMatch(ExactEventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), + #[serde(rename = "org.matrix.msc3966.exact_event_property_contains")] + ExactEventPropertyContains(ExactEventMatchCondition), #[serde(rename = "org.matrix.msc3952.is_user_mention")] IsUserMention, #[serde(rename = "org.matrix.msc3952.is_room_mention")] diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 328f681a2..7b33c30cc 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -14,7 +14,7 @@ from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union -from synapse.types import JsonDict, SimpleJsonValue +from synapse.types import JsonDict, JsonValue class PushRule: @property @@ -56,18 +56,19 @@ def get_base_rule_ids() -> Collection[str]: ... class PushRuleEvaluator: def __init__( self, - flattened_keys: Mapping[str, SimpleJsonValue], + flattened_keys: Mapping[str, JsonValue], has_mentions: bool, user_mentions: Set[str], room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - related_events_flattened: Mapping[str, Mapping[str, SimpleJsonValue]], + related_events_flattened: Mapping[str, Mapping[str, JsonValue]], related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, msc3758_exact_event_match: bool, + msc3966_exact_event_property_contains: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 6ac2f0c10..1d294f879 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -188,3 +188,8 @@ class ExperimentalConfig(Config): self.msc3958_supress_edit_notifs = experimental.get( "msc3958_supress_edit_notifs", False ) + + # MSC3966: exact_event_property_contains push rule condition. + self.msc3966_exact_event_property_contains = experimental.get( + "msc3966_exact_event_property_contains", False + ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f6a5bffb0..2e917c90c 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -44,7 +44,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator -from synapse.types import SimpleJsonValue +from synapse.types import JsonValue from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func @@ -259,13 +259,13 @@ class BulkPushRuleEvaluator: async def _related_events( self, event: EventBase - ) -> Dict[str, Dict[str, SimpleJsonValue]]: + ) -> Dict[str, Dict[str, JsonValue]]: """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation Returns: Mapping of relation type to flattened events. """ - related_events: Dict[str, Dict[str, SimpleJsonValue]] = {} + related_events: Dict[str, Dict[str, JsonValue]] = {} if self._related_event_match_enabled: related_event_id = event.content.get("m.relates_to", {}).get("event_id") relation_type = event.content.get("m.relates_to", {}).get("rel_type") @@ -429,6 +429,7 @@ class BulkPushRuleEvaluator: event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag self.hs.config.experimental.msc3758_exact_event_match, + self.hs.config.experimental.msc3966_exact_event_property_contains, ) users = rules_by_user.keys() @@ -502,18 +503,22 @@ RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] +def _is_simple_value(value: Any) -> bool: + return isinstance(value, (bool, str)) or type(value) is int or value is None + + def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, - result: Optional[Dict[str, SimpleJsonValue]] = None, + result: Optional[Dict[str, JsonValue]] = None, *, msc3783_escape_event_match_key: bool = False, -) -> Dict[str, SimpleJsonValue]: +) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, flatten it into a single layer dictionary by combining the keys & sub-keys. - String, integer, boolean, and null values are kept. All others are dropped. + String, integer, boolean, null or lists of those values are kept. All others are dropped. Transforms: @@ -542,8 +547,10 @@ def _flatten_dict( # nested fields. key = key.replace("\\", "\\\\").replace(".", "\\.") - if isinstance(value, (bool, str)) or type(value) is int or value is None: + if _is_simple_value(value): result[".".join(prefix + [key])] = value + elif isinstance(value, (list, tuple)): + result[".".join(prefix + [key])] = [v for v in value if _is_simple_value(v)] elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below _flatten_dict( diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 52e366c8a..33363867c 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -71,6 +71,7 @@ MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. # A "simple" (canonical) JSON value. SimpleJsonValue = Optional[Union[str, int, bool]] +JsonValue = Union[List[SimpleJsonValue], Tuple[SimpleJsonValue, ...], SimpleJsonValue] # A JSON-serialisable dict. JsonDict = Dict[str, Any] # A JSON-serialisable mapping; roughly speaking an immutable JSONDict. diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 660344734..0554d247b 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import PushRuleEvaluator from synapse.types import JsonDict, JsonMapping, UserID from synapse.util import Clock +from synapse.util.frozenutils import freeze from tests import unittest from tests.test_utils.event_injection import create_event, inject_member_event @@ -57,17 +58,24 @@ class FlattenDictTestCase(unittest.TestCase): ) def test_non_string(self) -> None: - """Booleans, ints, and nulls should be kept while other items are dropped.""" + """String, booleans, ints, nulls and list of those should be kept while other items are dropped.""" input: Dict[str, Any] = { "woo": "woo", "foo": True, "bar": 1, "baz": None, - "fuzz": [], + "fuzz": ["woo", True, 1, None, [], {}], "boo": {}, } self.assertEqual( - {"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input) + { + "woo": "woo", + "foo": True, + "bar": 1, + "baz": None, + "fuzz": ["woo", True, 1, None], + }, + _flatten_dict(input), ) def test_event(self) -> None: @@ -117,6 +125,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -128,6 +137,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -169,6 +179,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, msc3758_exact_event_match=True, + msc3966_exact_event_property_contains=True, ) def test_display_name(self) -> None: @@ -549,6 +560,42 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "incorrect types should not match", ) + def test_exact_event_property_contains(self) -> None: + """Check that exact_event_property_contains conditions work as expected.""" + + condition = { + "kind": "org.matrix.msc3966.exact_event_property_contains", + "key": "content.value", + "value": "foobaz", + } + self._assert_matches( + condition, + {"value": ["foobaz"]}, + "exact value should match", + ) + self._assert_matches( + condition, + {"value": ["foobaz", "bugz"]}, + "extra values should match", + ) + self._assert_not_matches( + condition, + {"value": ["FoobaZ"]}, + "values should match and be case-sensitive", + ) + self._assert_not_matches( + condition, + {"value": "foobaz"}, + "does not search in a string", + ) + + # it should work on frozendicts too + self._assert_matches( + condition, + freeze({"value": ["foobaz"]}), + "values should match on frozendicts", + ) + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({})