Hook up the push rules to the notifier

This commit is contained in:
Mark Haines 2016-03-03 14:57:45 +00:00
parent 2223204eba
commit ddf9e7b302
5 changed files with 43 additions and 18 deletions

View file

@ -647,8 +647,8 @@ class MessageHandler(BaseHandler):
user_id, messages, is_peeking=is_peeking user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken(token[1], 0, 0, 0, 0) end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -284,7 +284,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken.START):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """

View file

@ -36,6 +36,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request): def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
@ -51,8 +56,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
user_id = requester.user.to_string()
if 'attr' in spec: if 'attr' in spec:
yield self.set_rule_attr(requester.user.to_string(), spec, content) yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
if spec['rule_id'].startswith('.'): if spec['rule_id'].startswith('.'):
@ -77,8 +85,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
after = _namespaced_rule_id(spec, after[0]) after = _namespaced_rule_id(spec, after[0])
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.store.add_push_rule(
user_id=requester.user.to_string(), user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
conditions=conditions, conditions=conditions,
@ -86,6 +94,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
before=before, before=before,
after=after after=after
) )
self.notify_user(user_id)
except InconsistentRuleException as e: except InconsistentRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
except RuleNotFoundException as e: except RuleNotFoundException as e:
@ -98,13 +107,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.hs.get_datastore().delete_push_rule( yield self.store.delete_push_rule(
requester.user.to_string(), namespaced_rule_id user_id, namespaced_rule_id
) )
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
@ -115,14 +126,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user( rawrules = yield self.store.get_push_rules_for_user(user_id)
user.to_string()
)
ruleslist = [] ruleslist = []
for rawrule in rawrules: for rawrule in rawrules:
@ -138,8 +147,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
rules['global'] = _add_empty_priority_class_arrays(rules['global']) rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\ enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist: for r in ruleslist:
rulearray = None rulearray = None
@ -152,9 +160,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
pattern_type = c.pop("pattern_type", None) pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id": if pattern_type == "user_id":
c["pattern"] = user.to_string() c["pattern"] = user_id
elif pattern_type == "user_localpart": elif pattern_type == "user_localpart":
c["pattern"] = user.localpart c["pattern"] = requester.user.localpart
rulearray = rules['global'][template_name] rulearray = rules['global'][template_name]
@ -188,6 +196,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def notify_user(self, user_id):
stream_id = self.store.get_push_rules_stream_token()
self.notifier.on_new_event(
"push_rules_key", stream_id, users=[user_id]
)
def set_rule_attr(self, user_id, spec, val): def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled': if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val: if isinstance(val, dict) and "enabled" in val:
@ -198,7 +212,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
# bools directly, so let's not break them. # bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean") raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return self.hs.get_datastore().set_push_rule_enabled( return self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val user_id, namespaced_rule_id, val
) )
elif spec['attr'] == 'actions': elif spec['attr'] == 'actions':
@ -210,7 +224,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
if is_default_rule: if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS: if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.hs.get_datastore().set_push_rule_actions( return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule user_id, namespaced_rule_id, actions, is_default_rule
) )
else: else:

View file

@ -38,9 +38,12 @@ class EventSources(object):
name: cls(hs) name: cls(hs)
for name, cls in EventSources.SOURCE_TYPES.items() for name, cls in EventSources.SOURCE_TYPES.items()
} }
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_token(self, direction='f'): def get_current_token(self, direction='f'):
push_rules_key, _ = self.store.get_push_rules_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
yield self.sources["room"].get_current_key(direction) yield self.sources["room"].get_current_key(direction)
@ -57,5 +60,6 @@ class EventSources(object):
account_data_key=( account_data_key=(
yield self.sources["account_data"].get_current_key() yield self.sources["account_data"].get_current_key()
), ),
push_rules_key=push_rules_key,
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -115,6 +115,7 @@ class StreamToken(
"typing_key", "typing_key",
"receipt_key", "receipt_key",
"account_data_key", "account_data_key",
"push_rules_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -150,6 +151,7 @@ class StreamToken(
or (int(other.typing_key) < int(self.typing_key)) or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):
@ -174,6 +176,11 @@ class StreamToken(
return StreamToken(**d) return StreamToken(**d)
StreamToken.START = StreamToken(
*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
)
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1. """Tokens are positions between events. The token "s1" comes after event 1.