0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-22 08:30:07 +01:00

Make notifications go quicker

This commit is contained in:
Erik Johnston 2016-01-18 14:09:47 +00:00
parent 2068678b8c
commit f59b564507
4 changed files with 258 additions and 129 deletions

View file

@ -14,16 +14,16 @@
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson as json import ujson as json
from twisted.internet import defer from twisted.internet import defer
import baserules
from push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.types import UserID from synapse.types import UserID
import baserules
from push_rule_evaluator import PushRuleEvaluator
from synapse.events.utils import serialize_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,28 +35,25 @@ def decode_rule_json(rule):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_room_id(room_id, store): def _get_rules(room_id, user_ids, store):
users = yield store.get_users_in_room(room_id) rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_by_user = yield store.bulk_get_push_rules(users)
rules_by_user = { rules_by_user = {
uid: baserules.list_with_base_rules( uid: baserules.list_with_base_rules(
[decode_rule_json(rule_list) for rule_list in rules_by_user[uid]] [decode_rule_json(rule_list) for rule_list in rules_by_user.get(uid, [])],
if uid in rules_by_user else [],
UserID.from_string(uid), UserID.from_string(uid),
) )
for uid in users for uid in user_ids
} }
member_events = yield store.get_current_state( defer.returnValue(rules_by_user)
room_id=room_id,
event_type='m.room.member',
) @defer.inlineCallbacks
display_names = {} def evaluator_for_room_id(room_id, store):
for ev in member_events: users = yield store.get_users_in_room(room_id)
if ev.content.get("displayname"): rules_by_user = yield _get_rules(room_id, users, store)
display_names[ev.state_key] = ev.content.get("displayname")
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, display_names, users, store room_id, rules_by_user, users, store
)) ))
@ -69,10 +66,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562) (see https://matrix.org/jira/browse/SYN-562)
""" """
def __init__(self, room_id, rules_by_user, display_names, users_in_room, store): def __init__(self, room_id, rules_by_user, users_in_room, store):
self.room_id = room_id self.room_id = room_id
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
self.display_names = display_names
self.users_in_room = users_in_room self.users_in_room = users_in_room
self.store = store self.store = store
@ -80,15 +76,30 @@ class BulkPushRuleEvaluator:
def action_for_event_by_user(self, event, handler): def action_for_event_by_user(self, event, handler):
actions_by_user = {} actions_by_user = {}
for uid, rules in self.rules_by_user.items(): users_dict = yield self.store.are_guests(self.rules_by_user.keys())
display_name = None
if uid in self.display_names:
display_name = self.display_names[uid]
is_guest = yield self.store.is_guest(UserID.from_string(uid)) filtered_by_user = yield handler._filter_events_for_clients(
filtered = yield handler._filter_events_for_client( users_dict.items(), [event]
uid, [event], is_guest=is_guest )
)
evaluator = PushRuleEvaluatorForEvent.create(event, len(self.users_in_room))
condition_cache = {}
member_state = yield self.store.get_state_for_event(
event.event_id,
)
display_names = {}
for ev in member_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
for uid, rules in self.rules_by_user.items():
display_name = display_names.get(uid, None)
filtered = filtered_by_user[uid]
if len(filtered) == 0: if len(filtered) == 0:
continue continue
@ -96,29 +107,32 @@ class BulkPushRuleEvaluator:
if 'enabled' in rule and not rule['enabled']: if 'enabled' in rule and not rule['enabled']:
continue continue
# XXX: profile tags matches = _condition_checker(
if BulkPushRuleEvaluator.event_matches_rule( evaluator, rule['conditions'], display_name, condition_cache
event, rule, )
display_name, len(self.users_in_room), None if matches:
):
actions = [x for x in rule['actions'] if x != 'dont_notify'] actions = [x for x in rule['actions'] if x != 'dont_notify']
if len(actions) > 0: if actions:
actions_by_user[uid] = actions actions_by_user[uid] = actions
break break
defer.returnValue(actions_by_user) defer.returnValue(actions_by_user)
@staticmethod
def event_matches_rule(event, rule,
display_name, room_member_count, profile_tag):
matches = True
# passing the clock all the way into here is extremely awkward and push def _condition_checker(evaluator, conditions, display_name, cache):
# rules do not care about any of the relative timestamps, so we just for cond in conditions:
# pass 0 for the current time. _id = cond.get("_id", None)
client_event = serialize_event(event, 0) if _id:
res = cache.get(_id, None)
if res is False:
break
elif res is True:
continue
for cond in rule['conditions']: res = evaluator.matches(cond, display_name, None)
matches &= PushRuleEvaluator._event_fulfills_condition( if _id:
client_event, cond, display_name, room_member_count, profile_tag cache[_id] = res
)
return matches if res is False:
return False
return True

View file

@ -15,17 +15,22 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID
import baserules import baserules
import logging import logging
import simplejson as json import simplejson as json
import re import re
from synapse.types import UserID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
IS_GLOB = re.compile(r'[\?\*\[\]]')
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id) rawrules = yield store.get_push_rules_for_user(user_id)
@ -42,9 +47,34 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
)) ))
def _room_member_count(ev, condition, room_member_count):
if 'is' not in condition:
return False
m = INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
class PushRuleEvaluator: class PushRuleEvaluator:
DEFAULT_ACTIONS = [] DEFAULT_ACTIONS = []
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store): our_member_event, store):
@ -98,6 +128,8 @@ class PushRuleEvaluator:
room_members = yield self.store.get_users_in_room(room_id) room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members) room_member_count = len(room_members)
evaluator = PushRuleEvaluatorForEvent.create(ev, room_member_count)
for r in self.rules: for r in self.rules:
if r['rule_id'] in self.enabled_map: if r['rule_id'] in self.enabled_map:
r['enabled'] = self.enabled_map[r['rule_id']] r['enabled'] = self.enabled_map[r['rule_id']]
@ -105,21 +137,10 @@ class PushRuleEvaluator:
r['enabled'] = True r['enabled'] = True
if not r['enabled']: if not r['enabled']:
continue continue
matches = True
conditions = r['conditions'] conditions = r['conditions']
actions = r['actions'] actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count,
profile_tag=self.profile_tag
)
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify') # ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0: if len(actions) == 0:
logger.warn( logger.warn(
@ -127,6 +148,18 @@ class PushRuleEvaluator:
r['rule_id'], self.user_id r['rule_id'], self.user_id
) )
continue continue
matches = True
for c in conditions:
matches = evaluator.matches(c, my_display_name, self.profile_tag)
if not matches:
break
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
if matches: if matches:
logger.info( logger.info(
"%s matches for user %s, event %s", "%s matches for user %s, event %s",
@ -145,81 +178,84 @@ class PushRuleEvaluator:
) )
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS) defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges. class PushRuleEvaluatorForEvent(object):
r = re.sub(r'\\\[(\\\!|)(.*)\\\]', WORD_BOUNDARY = re.compile(r'\b')
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
re.sub(r'\\\-', '-', x.group(2)))), r) def __init__(self, event, body_parts, room_member_count):
return r self._event = event
self._body_parts = body_parts
self._room_member_count = room_member_count
self._value_cache = _flatten_dict(event)
@staticmethod @staticmethod
def _event_fulfills_condition(ev, condition, def create(event, room_member_count):
display_name, room_member_count, profile_tag): body = event.get("content", {}).get("body", None)
if body:
body_parts = PushRuleEvaluatorForEvent.WORD_BOUNDARY.split(body)
body_parts[:] = [
part.lower() for part in body_parts
]
else:
body_parts = []
return PushRuleEvaluatorForEvent(event, body_parts, room_member_count)
def matches(self, condition, display_name, profile_tag):
if condition['kind'] == 'event_match': if condition['kind'] == 'event_match':
if 'pattern' not in condition: return self._event_match(condition)
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
elif condition['kind'] == 'device': elif condition['kind'] == 'device':
if 'profile_tag' not in condition: if 'profile_tag' not in condition:
return True return True
return condition['profile_tag'] == profile_tag return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name': elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different return self._contains_display_name(display_name)
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
elif condition['kind'] == 'room_member_count': elif condition['kind'] == 'room_member_count':
if 'is' not in condition: return _room_member_count(
return False self._event, condition, self._room_member_count
m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is']) )
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
else: else:
return True return True
def _event_match(self, condition):
pattern = condition.get('pattern', None)
if not pattern:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
matcher = _glob_to_matcher(pattern)
for part in self._body_parts:
if matcher(part):
return True
return False
else:
haystack = self._get_value(condition['key'])
if haystack is None:
return False
matcher = _glob_to_matcher(pattern)
return matcher(haystack.lower())
def _contains_display_name(self, display_name):
if not display_name:
return False
lower_display_name = display_name.lower()
for part in self._body_parts:
if part == lower_display_name:
return True
return False
def _get_value(self, dotted_key):
return self._value_cache.get(dotted_key, None)
def _value_for_dotted_key(dotted_key, event): def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".") parts = dotted_key.split(".")
@ -229,4 +265,42 @@ def _value_for_dotted_key(dotted_key, event):
return None return None
val = val[parts[0]] val = val[parts[0]]
parts = parts[1:] parts = parts[1:]
return val return val
def _glob_to_matcher(glob):
glob = glob.lower()
if not IS_GLOB.search(glob):
return lambda value: value == glob
r = re.escape(glob)
r = r.replace(r'\*', '.*?')
r = r.replace(r'\?', '.')
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
r = r + "$"
r = re.compile(r)
return lambda value: r.match(value)
def _flatten_dict(d, prefix=[], result={}):
for key, value in d.items():
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result)
return result

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -60,6 +60,27 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results r['rule_id']: False if r['enabled'] == 0 else True for r in results
}) })
@cached()
def _get_push_rules_enabled_for_user(self, user_id):
def f(txn):
sql = (
"SELECT pr.*"
" FROM push_rules AS pr"
" LEFT JOIN push_rules_enable AS pre"
" ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id"
" WHERE pr.user_name = ?"
" AND (pre.enabled IS NULL OR pre.enabled = 1)"
" ORDER BY pr.priority_class DESC, pr.priority DESC"
)
txn.execute(sql, (user_id,))
return self.cursor_to_dict(txn)
return self.runInteraction(
"_get_push_rules_enabled_for_user", f
)
# @cachedList(cache=_get_push_rules_enabled_for_user.cache, list_name="user_ids",
# num_args=1, inlineCallbacks=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def bulk_get_push_rules(self, user_ids): def bulk_get_push_rules(self, user_ids):
if not user_ids: if not user_ids:

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
class RegistrationStore(SQLBaseStore): class RegistrationStore(SQLBaseStore):
@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedInlineCallbacks() @cachedInlineCallbacks()
def is_guest(self, user): def is_guest(self, user_id):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user.to_string()}, keyvalues={"name": user_id},
retcol="is_guest", retcol="is_guest",
allow_none=True, allow_none=True,
desc="is_guest", desc="is_guest",
@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
",".join("?" for _ in user_ids),
)
rows = yield self._execute(
"are_guests", self.cursor_to_dict, sql, *user_ids
)
result = {user_id: False for user_id in user_ids}
result.update({
row["name"]: bool(row["is_guest"])
for row in rows
})
defer.returnValue(result)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id" "SELECT users.name, users.is_guest, access_tokens.id as token_id"