Experimental support for MSC3772 (#12740)

Implements the following behind an experimental configuration flag:

* A new push rule kind for mutually related events.
* A new default push rule (`.m.rule.thread_reply`) under an unstable prefix.

This is missing part of MSC3772:

* The `.m.rule.thread_reply_to_me` push rule, this depends on MSC3664 / #11804.
This commit is contained in:
Patrick Cloke 2022-05-24 09:23:23 -04:00 committed by GitHub
parent 0b3423fd51
commit 88ce3080d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 287 additions and 6 deletions

View file

@ -0,0 +1 @@
Experimental support for [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772): Push rule for mutually related events.

View file

@ -84,3 +84,6 @@ class ExperimentalConfig(Config):
# MSC3786 (Add a default push rule to ignore m.room.server_acl events) # MSC3786 (Add a default push rule to ignore m.room.server_acl events)
self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
# MSC3772: A push rule for mutual relations.
self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)

View file

@ -139,6 +139,7 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
{ {
"kind": "event_match", "kind": "event_match",
"key": "content.body", "key": "content.body",
# Match the localpart of the requester's MXID.
"pattern_type": "user_localpart", "pattern_type": "user_localpart",
} }
], ],
@ -191,6 +192,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
"pattern": "invite", "pattern": "invite",
"_cache_key": "_invite_member", "_cache_key": "_invite_member",
}, },
# Match the requester's MXID.
{"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
], ],
"actions": [ "actions": [
@ -350,6 +352,18 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
{"set_tweak": "highlight", "value": False}, {"set_tweak": "highlight", "value": False},
], ],
}, },
{
"rule_id": "global/underride/.org.matrix.msc3772.thread_reply",
"conditions": [
{
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.thread",
# Match the requester's MXID.
"sender_type": "user_id",
}
],
"actions": ["notify", {"set_tweak": "highlight", "value": False}],
},
{ {
"rule_id": "global/underride/.m.rule.message", "rule_id": "global/underride/.m.rule.message",
"conditions": [ "conditions": [

View file

@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -121,6 +122,9 @@ class BulkPushRuleEvaluator:
resizable=False, resizable=False,
) )
# Whether to support MSC3772 is supported.
self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
async def _get_rules_for_event( async def _get_rules_for_event(
self, event: EventBase, context: EventContext self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@ -192,6 +196,60 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level return pl_event.content if pl_event else {}, sender_level
async def _get_mutual_relations(
self, event: EventBase, rules: Iterable[Dict[str, Any]]
) -> Dict[str, Set[Tuple[str, str]]]:
"""
Fetch event metadata for events which related to the same event as the given event.
If the given event has no relation information, returns an empty dictionary.
Args:
event_id: The event ID which is targeted by relations.
rules: The push rules which will be processed for this event.
Returns:
A dictionary of relation type to:
A set of tuples of:
The sender
The event type
"""
# If the experimental feature is not enabled, skip fetching relations.
if not self._relations_match_enabled:
return {}
# If the event does not have a relation, then cannot have any mutual
# relations.
relation = relation_from_event(event)
if not relation:
return {}
# Pre-filter to figure out which relation types are interesting.
rel_types = set()
for rule in rules:
# Skip disabled rules.
if "enabled" in rule and not rule["enabled"]:
continue
for condition in rule["conditions"]:
if condition["kind"] != "org.matrix.msc3772.relation_match":
continue
# rel_type is required.
rel_type = condition.get("rel_type")
if rel_type:
rel_types.add(rel_type)
# If no valid rules were found, no mutual relations.
if not rel_types:
return {}
# If any valid rules were found, fetch the mutual relations.
return await self.store.get_mutual_event_relations(
relation.parent_id, rel_types
)
@measure_func("action_for_event_by_user") @measure_func("action_for_event_by_user")
async def action_for_event_by_user( async def action_for_event_by_user(
self, event: EventBase, context: EventContext self, event: EventBase, context: EventContext
@ -216,8 +274,17 @@ class BulkPushRuleEvaluator:
sender_power_level, sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context) ) = await self._get_power_levels_and_sender_level(event, context)
relations = await self._get_mutual_relations(
event, itertools.chain(*rules_by_user.values())
)
evaluator = PushRuleEvaluatorForEvent( evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels event,
len(room_members),
sender_power_level,
power_levels,
relations,
self._relations_match_enabled,
) )
# If the event is not a state event check if any users ignore the sender. # If the event is not a state event check if any users ignore the sender.

View file

@ -48,6 +48,10 @@ def format_push_rules_for_user(
elif pattern_type == "user_localpart": elif pattern_type == "user_localpart":
c["pattern"] = user.localpart c["pattern"] = user.localpart
sender_type = c.pop("sender_type", None)
if sender_type == "user_id":
c["sender"] = user.to_string()
rulearray = rules["global"][template_name] rulearray = rules["global"][template_name]
template_rule = _rule_to_template(r) template_rule = _rule_to_template(r)

View file

@ -15,7 +15,7 @@
import logging import logging
import re import re
from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union
from matrix_common.regex import glob_to_regex, to_word_pattern from matrix_common.regex import glob_to_regex, to_word_pattern
@ -120,11 +120,15 @@ class PushRuleEvaluatorForEvent:
room_member_count: int, room_member_count: int,
sender_power_level: int, sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]], power_levels: Dict[str, Union[int, Dict[str, int]]],
relations: Dict[str, Set[Tuple[str, str]]],
relations_match_enabled: bool,
): ):
self._event = event self._event = event
self._room_member_count = room_member_count self._room_member_count = room_member_count
self._sender_power_level = sender_power_level self._sender_power_level = sender_power_level
self._power_levels = power_levels self._power_levels = power_levels
self._relations = relations
self._relations_match_enabled = relations_match_enabled
# Maps strings of e.g. 'content.body' -> event["content"]["body"] # Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
@ -188,7 +192,16 @@ class PushRuleEvaluatorForEvent:
return _sender_notification_permission( return _sender_notification_permission(
self._event, condition, self._sender_power_level, self._power_levels self._event, condition, self._sender_power_level, self._power_levels
) )
elif (
condition["kind"] == "org.matrix.msc3772.relation_match"
and self._relations_match_enabled
):
return self._relation_match(condition, user_id)
else: else:
# XXX This looks incorrect -- we have reached an unknown condition
# kind and are unconditionally returning that it matches. Note
# that it seems possible to provide a condition to the /pushrules
# endpoint with an unknown kind, see _rule_tuple_from_request_object.
return True return True
def _event_match(self, condition: dict, user_id: str) -> bool: def _event_match(self, condition: dict, user_id: str) -> bool:
@ -256,6 +269,41 @@ class PushRuleEvaluatorForEvent:
return bool(r.search(body)) return bool(r.search(body))
def _relation_match(self, condition: dict, user_id: str) -> bool:
"""
Check an "relation_match" push rule condition.
Args:
condition: The "event_match" push rule condition to match.
user_id: The user's MXID.
Returns:
True if the condition matches the event, False otherwise.
"""
rel_type = condition.get("rel_type")
if not rel_type:
logger.warning("relation_match condition missing rel_type")
return False
sender_pattern = condition.get("sender")
if sender_pattern is None:
sender_type = condition.get("sender_type")
if sender_type == "user_id":
sender_pattern = user_id
type_pattern = condition.get("type")
# If any other relations matches, return True.
for sender, event_type in self._relations.get(rel_type, ()):
if sender_pattern and not _glob_matches(sender_pattern, sender):
continue
if type_pattern and not _glob_matches(type_pattern, event_type):
continue
# All values must have matched.
return True
# No relations matched.
return False
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(

View file

@ -1828,6 +1828,10 @@ class PersistEventsStore:
self.store.get_aggregation_groups_for_event.invalidate, self.store.get_aggregation_groups_for_event.invalidate,
(relation.parent_id,), (relation.parent_id,),
) )
txn.call_after(
self.store.get_mutual_event_relations_for_rel_type.invalidate,
(relation.parent_id,),
)
if relation.rel_type == RelationTypes.REPLACE: if relation.rel_type == RelationTypes.REPLACE:
txn.call_after( txn.call_after(
@ -2004,6 +2008,11 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream( self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,) txn, self.store.get_thread_participated, (redacted_relates_to,)
) )
self.store._invalidate_cache_and_stream(
txn,
self.store.get_mutual_event_relations_for_rel_type,
(redacted_relates_to,),
)
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id} txn, table="event_relations", keyvalues={"event_id": redacted_event_id}

View file

@ -61,6 +61,11 @@ def _is_experimental_rule_enabled(
and not experimental_config.msc3786_enabled and not experimental_config.msc3786_enabled
): ):
return False return False
if (
rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
and not experimental_config.msc3772_enabled
):
return False
return True return True

View file

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import defaultdict
from typing import ( from typing import (
Collection, Collection,
Dict, Dict,
@ -767,6 +768,57 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
) )
@cached(iterable=True)
async def get_mutual_event_relations_for_rel_type(
self, event_id: str, relation_type: str
) -> Set[Tuple[str, str]]:
raise NotImplementedError()
@cachedList(
cached_method_name="get_mutual_event_relations_for_rel_type",
list_name="relation_types",
)
async def get_mutual_event_relations(
self, event_id: str, relation_types: Collection[str]
) -> Dict[str, Set[Tuple[str, str]]]:
"""
Fetch event metadata for events which related to the same event as the given event.
If the given event has no relation information, returns an empty dictionary.
Args:
event_id: The event ID which is targeted by relations.
relation_types: The relation types to check for mutual relations.
Returns:
A dictionary of relation type to:
A set of tuples of:
The sender
The event type
"""
rel_type_sql, rel_type_args = make_in_list_sql_clause(
self.database_engine, "relation_type", relation_types
)
sql = f"""
SELECT DISTINCT relation_type, sender, type FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND {rel_type_sql}
"""
def _get_event_relations(
txn: LoggingTransaction,
) -> Dict[str, Set[Tuple[str, str]]]:
txn.execute(sql, [event_id] + rel_type_args)
result = defaultdict(set)
for rel_type, sender, type in txn.fetchall():
result[rel_type].add((sender, type))
return result
return await self.db_pool.runInteraction(
"get_event_relations", _get_event_relations
)
class RelationsStore(RelationsWorkerStore): class RelationsStore(RelationsWorkerStore):
pass pass

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Optional, Union from typing import Dict, Optional, Set, Tuple, Union
import frozendict import frozendict
@ -26,7 +26,12 @@ from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase): class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent: def _get_evaluator(
self,
content: JsonDict,
relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
relations_match_enabled: bool = False,
) -> PushRuleEvaluatorForEvent:
event = FrozenEvent( event = FrozenEvent(
{ {
"event_id": "$event_id", "event_id": "$event_id",
@ -42,7 +47,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
sender_power_level = 0 sender_power_level = 0
power_levels: Dict[str, Union[int, Dict[str, int]]] = {} power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent( return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels event,
room_member_count,
sender_power_level,
power_levels,
relations or set(),
relations_match_enabled,
) )
def test_display_name(self) -> None: def test_display_name(self) -> None:
@ -276,3 +286,71 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
push_rule_evaluator.tweaks_for_actions(actions), push_rule_evaluator.tweaks_for_actions(actions),
{"sound": "default", "highlight": True}, {"sound": "default", "highlight": True},
) )
def test_relation_match(self) -> None:
"""Test the relation_match push rule kind."""
# Check if the experimental feature is disabled.
evaluator = self._get_evaluator(
{}, {"m.annotation": {("@user:test", "m.reaction")}}
)
condition = {"kind": "relation_match"}
# Oddly, an unknown condition always matches.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# A push rule evaluator with the experimental rule enabled.
evaluator = self._get_evaluator(
{}, {"m.annotation": {("@user:test", "m.reaction")}}, True
)
# Check just relation type.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# Check relation type and sender.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"sender": "@user:test",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"sender": "@other:test",
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
# Check relation type and event type.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"type": "m.reaction",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# Check just sender, this fails since rel_type is required.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"sender": "@user:test",
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
# Check sender glob.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"sender": "@*:test",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
# Check event type glob.
condition = {
"kind": "org.matrix.msc3772.relation_match",
"rel_type": "m.annotation",
"event_type": "*.reaction",
}
self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))