Fix non-ASCII pushrules (#4248)

This commit is contained in:
Amber Brown 2018-12-04 22:44:02 +11:00 committed by Richard van der Hoff
parent dd27e47b5c
commit fd96dd75a3
2 changed files with 24 additions and 12 deletions

1
changelog.d/4165.bugfix Normal file
View file

@ -0,0 +1 @@
Pushrules can now again be made with non-ASCII rule IDs.

View file

@ -42,7 +42,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request): def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path([x.decode('utf8') for x in request.postpath])
try: try:
priority_class = _priority_class_from_spec(spec) priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e: except InvalidRuleException as e:
@ -103,7 +103,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request): def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path([x.decode('utf8') for x in request.postpath])
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -134,7 +134,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
rules = format_push_rules_for_user(requester.user, rules) rules = format_push_rules_for_user(requester.user, rules)
path = request.postpath[1:] path = [x.decode('utf8') for x in request.postpath][1:]
if path == []: if path == []:
# we're a reference impl: pedantry is our job. # we're a reference impl: pedantry is our job.
@ -142,11 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
) )
if path[0] == b'': if path[0] == '':
defer.returnValue((200, rules)) defer.returnValue((200, rules))
elif path[0] == b'global': elif path[0] == 'global':
path = [x.decode('ascii') for x in path[1:]] result = _filter_ruleset_with_path(rules['global'], path[1:])
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result)) defer.returnValue((200, result))
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -190,12 +189,24 @@ class PushRuleRestServlet(ClientV1RestServlet):
def _rule_spec_from_path(path): def _rule_spec_from_path(path):
"""Turn a sequence of path components into a rule spec
Args:
path (sequence[unicode]): the URL path components.
Returns:
dict: rule spec dict, containing scope/template/rule_id entries,
and possibly attr.
Raises:
UnrecognizedRequestError if the path components cannot be parsed.
"""
if len(path) < 2: if len(path) < 2:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
if path[0] != b'pushrules': if path[0] != 'pushrules':
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
scope = path[1].decode('ascii') scope = path[1]
path = path[2:] path = path[2:]
if scope != 'global': if scope != 'global':
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -203,13 +214,13 @@ def _rule_spec_from_path(path):
if len(path) == 0: if len(path) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
template = path[0].decode('ascii') template = path[0]
path = path[1:] path = path[1:]
if len(path) == 0 or len(path[0]) == 0: if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
rule_id = path[0].decode('ascii') rule_id = path[0]
spec = { spec = {
'scope': scope, 'scope': scope,
@ -220,7 +231,7 @@ def _rule_spec_from_path(path):
path = path[1:] path = path[1:]
if len(path) > 0 and len(path[0]) > 0: if len(path) > 0 and len(path[0]) > 0:
spec['attr'] = path[0].decode('ascii') spec['attr'] = path[0]
return spec return spec