0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-29 07:58:19 +02:00

Add rest API & store for creating push rules

Also make unrecognised request error look more like synapse errors
because it makes it easier to throw them from within rest classes.
This commit is contained in:
David Baker 2015-01-22 17:37:12 +00:00
parent 5d5932d493
commit dc93860619
2 changed files with 391 additions and 0 deletions

195
synapse/rest/push_rule.py Normal file
View file

@ -0,0 +1,195 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from base import RestServlet, client_path_pattern
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
import json
class PushRuleRestServlet(RestServlet):
PATTERN = client_path_pattern("/pushrules/.*$")
def rule_spec_from_path(self, path):
if len(path) < 2:
raise UnrecognizedRequestError()
if path[0] != 'pushrules':
raise UnrecognizedRequestError()
scope = path[1]
path = path[2:]
if scope not in ['global', 'device']:
raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
template = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
rule_id = path[0]
spec = {
'scope' : scope,
'template': template,
'rule_id': rule_id
}
if device:
spec['device'] = device
return spec
def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj):
if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
conditions = req_obj['conditions']
for c in conditions:
if 'kind' not in c:
raise InvalidRuleException("Condition without 'kind'")
elif rule_template == 'room':
conditions = [{
'kind': 'event_match',
'key': 'room_id',
'pattern': rule_id
}]
elif rule_template == 'sender':
conditions = [{
'kind': 'event_match',
'key': 'user_id',
'pattern': rule_id
}]
elif rule_template == 'content':
if 'pattern' not in req_obj:
raise InvalidRuleException("Content rule missing 'pattern'")
conditions = [{
'kind': 'event_match',
'key': 'content.body',
'pattern': req_obj['pattern']
}]
else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template))
if 'actions' not in req_obj:
raise InvalidRuleException("No actions found")
actions = req_obj['actions']
for a in actions:
if a in ['notify', 'dont-notify', 'coalesce']:
pass
elif isinstance(a, dict) and 'set_sound' in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
return (conditions, actions)
def priority_class_from_spec(self, spec):
map = {
'underride': 0,
'sender': 1,
'room': 2,
'content': 3,
'override': 4
}
if spec['template'] not in map.keys():
raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
pc = map[spec['template']]
if spec['scope'] == 'device':
pc += 5
return pc
@defer.inlineCallbacks
def on_PUT(self, request):
spec = self.rule_spec_from_path(request.postpath)
try:
priority_class = self.priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
try:
(conditions, actions) = self.rule_tuple_from_request_object(
spec['template'],
spec['rule_id'],
content
)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
before = request.args.get("before", None)
if before and len(before):
before = before[0]
after = request.args.get("after", None)
if after and len(after):
after = after[0]
try:
yield self.hs.get_datastore().add_push_rule(
user_name=user.to_string(),
rule_id=spec['rule_id'],
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after
)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
except RuleNotFoundException:
raise SynapseError(400, "before/after rule not found")
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
class InvalidRuleException(Exception):
pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server)

View file

@ -0,0 +1,196 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from ._base import SQLBaseStore, Table
from twisted.internet import defer
import logging
import copy
import json
logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks
def get_push_rules_for_user_name(self, user_name):
sql = (
"SELECT "+",".join(PushRuleTable.fields)+
"FROM pushers "
"WHERE user_name = ?"
)
rows = yield self._execute(None, sql, user_name)
dicts = []
for r in rows:
d = {}
for i, f in enumerate(PushRuleTable.fields):
d[f] = r[i]
dicts.append(d)
defer.returnValue(dicts)
@defer.inlineCallbacks
def add_push_rule(self, **kwargs):
vals = copy.copy(kwargs)
if 'conditions' in vals:
vals['conditions'] = json.dumps(vals['conditions'])
if 'actions' in vals:
vals['actions'] = json.dumps(vals['actions'])
# we could check the rest of the keys are valid column names
# but sqlite will do that anyway so I think it's just pointless.
if 'id' in vals:
del vals['id']
if 'after' in kwargs or 'before' in kwargs:
ret = yield self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
**vals
)
defer.returnValue(ret)
else:
ret = yield self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
**vals
)
defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
after = None
relative_to_rule = None
if 'after' in kwargs and kwargs['after']:
after = kwargs['after']
relative_to_rule = after
if 'before' in kwargs and kwargs['before']:
relative_to_rule = kwargs['before']
# get the priority of the rule we're inserting after/before
sql = (
"SELECT priority_class, priority FROM "+PushRuleTable.table_name+
" WHERE user_name = ? and rule_id = ?"
)
txn.execute(sql, (user_name, relative_to_rule))
res = txn.fetchall()
if not res:
raise RuleNotFoundException()
(priority_class, base_rule_priority) = res[0]
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
new_rule = copy.copy(kwargs)
if 'before' in new_rule:
del new_rule['before']
if 'after' in new_rule:
del new_rule['after']
new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name
# check if the priority before/after is free
new_rule_priority = base_rule_priority
if after:
new_rule_priority -= 1
else:
new_rule_priority += 1
new_rule['priority'] = new_rule_priority
sql = (
"SELECT COUNT(*) FROM "+PushRuleTable.table_name+
" WHERE user_name = ? AND priority_class = ? AND priority = ?"
)
txn.execute(sql, (user_name, priority_class, new_rule_priority))
res = txn.fetchall()
num_conflicting = res[0][0]
# if there are conflicting rules, bump everything
if num_conflicting:
sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority "
if after:
sql += "-1"
else:
sql += "+1"
sql += " WHERE user_name = ? AND priority_class = ? AND priority "
if after:
sql += "<= ?"
else:
sql += ">= ?"
txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
def _add_push_rule_highest_priority_txn(self, txn, user_name, priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM "+PushRuleTable.table_name+
" WHERE user_name = ? and priority_class = ?"
)
txn.execute(sql, (user_name, priority_class))
res = txn.fetchall()
(how_many, highest_prio) = res[0]
new_prio = 0
if how_many > 0:
new_prio = highest_prio + 1
# and insert the new rule
new_rule = copy.copy(kwargs)
if 'id' in new_rule:
del new_rule['id']
new_rule['user_name'] = user_name
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
class RuleNotFoundException(Exception):
pass
class InconsistentRuleException(Exception):
pass
class PushRuleTable(Table):
table_name = "push_rules"
fields = [
"id",
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
]
EntryType = collections.namedtuple("PushRuleEntry", fields)