mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 05:43:52 +01:00
Add API for getting/setting enabled-ness of push rules.
This commit is contained in:
parent
944003021b
commit
1959088156
3 changed files with 77 additions and 9 deletions
|
@ -37,7 +37,7 @@ def make_base_rules(user, kind):
|
|||
|
||||
for r in rules:
|
||||
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
|
||||
r['default'] = True
|
||||
r['default'] = True # Deprecated, left for backwards compat
|
||||
|
||||
return rules
|
||||
|
||||
|
@ -45,7 +45,7 @@ def make_base_rules(user, kind):
|
|||
def make_base_content_rules(user):
|
||||
return [
|
||||
{
|
||||
'rule_id': '.m.rule.contains_user_name',
|
||||
'rule_id': 'global/content/.m.rule.contains_user_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
|
@ -67,7 +67,7 @@ def make_base_content_rules(user):
|
|||
def make_base_override_rules():
|
||||
return [
|
||||
{
|
||||
'rule_id': '.m.rule.contains_display_name',
|
||||
'rule_id': 'global/override/.m.rule.contains_display_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'contains_display_name'
|
||||
|
@ -82,7 +82,7 @@ def make_base_override_rules():
|
|||
]
|
||||
},
|
||||
{
|
||||
'rule_id': '.m.rule.room_two_members',
|
||||
'rule_id': 'global/override/.m.rule.room_two_members',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'room_member_count',
|
||||
|
|
|
@ -50,6 +50,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
content = _parse_json(request)
|
||||
|
||||
if 'attr' in spec:
|
||||
self.set_rule_attr(user.to_string(), spec, content)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
try:
|
||||
(conditions, actions) = _rule_tuple_from_request_object(
|
||||
spec['template'],
|
||||
|
@ -124,6 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
|
||||
|
||||
enabled_map = yield self.hs.get_datastore().\
|
||||
get_push_rules_enabled_for_user_name(user.to_string())
|
||||
|
||||
for r in ruleslist:
|
||||
rulearray = None
|
||||
|
||||
|
@ -149,6 +156,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
template_rule = _rule_to_template(r)
|
||||
if template_rule:
|
||||
template_rule['enabled'] = True
|
||||
if r['rule_id'] in enabled_map:
|
||||
template_rule['enabled'] = enabled_map[r['rule_id']]
|
||||
rulearray.append(template_rule)
|
||||
|
||||
path = request.postpath[1:]
|
||||
|
@ -189,6 +199,24 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
def set_rule_attr(self, user_name, spec, val):
|
||||
if spec['attr'] == 'enabled':
|
||||
if not isinstance(val, bool):
|
||||
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
self.hs.get_datastore().set_push_rule_enabled(
|
||||
user_name, namespaced_rule_id, val
|
||||
)
|
||||
else:
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
def get_rule_attr(self, user_name, namespaced_rule_id, attr):
|
||||
if attr == 'enabled':
|
||||
return self.hs.get_datastore().get_push_rule_enabled_by_user_name_rule_id(
|
||||
user_name, namespaced_rule_id
|
||||
)
|
||||
else:
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
def _rule_spec_from_path(path):
|
||||
if len(path) < 2:
|
||||
|
@ -226,6 +254,12 @@ def _rule_spec_from_path(path):
|
|||
}
|
||||
if device:
|
||||
spec['profile_tag'] = device
|
||||
|
||||
path = path[1:]
|
||||
|
||||
if len(path) > 0 and len(path[0]) > 0:
|
||||
spec['attr'] = path[0]
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
|
@ -319,10 +353,23 @@ def _filter_ruleset_with_path(ruleset, path):
|
|||
if path[0] == '':
|
||||
return ruleset[template_kind]
|
||||
rule_id = path[0]
|
||||
|
||||
the_rule = None
|
||||
for r in ruleset[template_kind]:
|
||||
if r['rule_id'] == rule_id:
|
||||
return r
|
||||
raise NotFoundError
|
||||
the_rule = r
|
||||
if the_rule is None:
|
||||
raise NotFoundError
|
||||
|
||||
path = path[1:]
|
||||
if len(path) == 0:
|
||||
return the_rule
|
||||
|
||||
attr = path[0]
|
||||
if attr in the_rule:
|
||||
return the_rule[attr]
|
||||
else:
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
|
||||
def _priority_class_from_spec(spec):
|
||||
|
@ -399,9 +446,6 @@ class InvalidRuleException(Exception):
|
|||
def _parse_json(request):
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
if type(content) != dict:
|
||||
raise SynapseError(400, "Content must be a JSON object.",
|
||||
errcode=Codes.NOT_JSON)
|
||||
return content
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||
|
|
|
@ -56,6 +56,17 @@ class PushRuleStore(SQLBaseStore):
|
|||
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_push_rule_enabled_by_user_name_rule_id(self, user_name, rule_id):
|
||||
results = yield self._simple_select_list(
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
['enabled']
|
||||
)
|
||||
if len(results) == 0:
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(results[0])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_push_rule(self, before, after, **kwargs):
|
||||
vals = copy.copy(kwargs)
|
||||
|
@ -204,6 +215,19 @@ class PushRuleStore(SQLBaseStore):
|
|||
{'user_name': user_name, 'rule_id': rule_id}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_push_rule_enabled(self, user_name, rule_id, enabled):
|
||||
if enabled:
|
||||
yield self._simple_delete_one(
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name, 'rule_id': rule_id}
|
||||
)
|
||||
else:
|
||||
yield self._simple_upsert(
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
{'enabled': False}
|
||||
)
|
||||
|
||||
class RuleNotFoundException(Exception):
|
||||
pass
|
||||
|
|
Loading…
Reference in a new issue