Add a parse_json_object function

to deduplicate all the copy+pasted _parse_json functions. Also document
the parse_.* functions.
This commit is contained in:
Mark Haines 2016-03-09 11:26:26 +00:00 committed by review.rocks
parent 158a322e82
commit b7dbe5147a
11 changed files with 97 additions and 121 deletions

View file

@ -15,14 +15,27 @@
""" This module contains base REST classes for constructing REST servlets. """ """ This module contains base REST classes for constructing REST servlets. """
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, Codes
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False): def parse_integer(request, name, default=None, required=False):
"""Parse an integer parameter from the request string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: An int value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present and not an integer.
"""
if name in request.args: if name in request.args:
try: try:
return int(request.args[name][0]) return int(request.args[name][0])
@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False):
else: else:
if required: if required:
message = "Missing integer query parameter %r" % (name,) message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_boolean(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False):
"""Parse a boolean parameter from the request query string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: A bool value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false".
"""
if name in request.args: if name in request.args:
try: try:
return { return {
@ -53,30 +79,64 @@ def parse_boolean(request, name, default=None, required=False):
else: else:
if required: if required:
message = "Missing boolean query parameter %r" % (name,) message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_string(request, name, default=None, required=False, def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string.
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:param allowed_values (list): List of allowed values for the string,
or None if any value is allowed, defaults to None
:return: A string value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
if name in request.args: if name in request.args:
value = request.args[name][0] value = request.args[name][0]
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)
) )
raise SynapseError(message) raise SynapseError(400, message)
else: else:
return value return value
else: else:
if required: if required:
message = "Missing %s query parameter %r" % (param_type, name) message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
return default return default
def parse_json_object_from_request(request):
"""Parse a JSON object from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:raises
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
message = "Content must be a JSON object."
raise SynapseError(400, message, errcode=Codes.BAD_JSON)
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
class RestServlet(object): class RestServlet(object):
""" A Synapse REST Servlet. """ A Synapse REST Servlet.

View file

@ -18,9 +18,10 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import RoomAlias from synapse.types import RoomAlias
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
@ -45,7 +46,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
content = _parse_json(request) content = parse_json_object_from_request(request)
if "room_id" not in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, "Missing room_id key",
errcode=Codes.BAD_JSON) errcode=Codes.BAD_JSON)
@ -135,14 +136,3 @@ class ClientDirectoryServer(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, LoginError, Codes from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.types import UserID from synapse.types import UserID
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -79,7 +80,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
login_submission = _parse_json(request) login_submission = parse_json_object_from_request(request)
try: try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if login_submission["type"] == LoginRestServlet.PASS_TYPE:
if not self.password_enabled: if not self.password_enabled:
@ -400,18 +401,6 @@ class CasTicketServlet(ClientV1RestServlet):
return (user, attributes) return (user, attributes)
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(
400, "Content must be a JSON object.", errcode=Codes.BAD_JSON
)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled: if hs.config.saml2_enabled:

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
) )
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from synapse.storage.push_rule import ( from synapse.storage.push_rule import (
@ -25,8 +25,7 @@ from synapse.storage.push_rule import (
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.baserules import BASE_RULE_IDS from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.http.servlet import parse_json_object_from_request
import simplejson as json
class PushRuleRestServlet(ClientV1RestServlet): class PushRuleRestServlet(ClientV1RestServlet):
@ -52,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
if '/' in spec['rule_id'] or '\\' in spec['rule_id']: if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
content = _parse_json(request) content = parse_json_object_from_request(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -341,14 +340,5 @@ class InvalidRuleException(Exception):
pass pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server) PushRuleRestServlet(hs).register(http_server)

View file

@ -17,9 +17,10 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ class PusherRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
content = _parse_json(request) content = parse_json_object_from_request(request)
pusher_pool = self.hs.get_pusherpool() pusher_pool = self.hs.get_pusherpool()
@ -92,17 +93,5 @@ class PusherRestServlet(ClientV1RestServlet):
return 200, {} return 200, {}
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
PusherRestServlet(hs).register(http_server) PusherRestServlet(hs).register(http_server)

View file

@ -20,12 +20,12 @@ from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from hashlib import sha1 from hashlib import sha1
import hmac import hmac
import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -98,7 +98,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
register_json = _parse_json(request) register_json = parse_json_object_from_request(request)
session = (register_json["session"] session = (register_json["session"]
if "session" in register_json else None) if "session" in register_json else None)
@ -355,15 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet):
) )
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View file

@ -22,6 +22,7 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request
import simplejson as json import simplejson as json
import logging import logging
@ -137,7 +138,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
event_dict = { event_dict = {
"type": event_type, "type": event_type,
@ -179,7 +180,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_nonmember_event( event = yield msg_handler.create_and_send_nonmember_event(
@ -229,7 +230,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
) )
try: try:
content = _parse_json(request) content = parse_json_object_from_request(request)
except: except:
# Turns out we used to ignore the body entirely, and some clients # Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies. # cheekily send invalid bodies.
@ -433,7 +434,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
try: try:
content = _parse_json(request) content = parse_json_object_from_request(request)
except: except:
# Turns out we used to ignore the body entirely, and some clients # Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies. # cheekily send invalid bodies.
@ -500,7 +501,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_nonmember_event( event = yield msg_handler.create_and_send_nonmember_event(
@ -548,7 +549,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))
content = _parse_json(request) content = parse_json_object_from_request(request)
typing_handler = self.handlers.typing_notification_handler typing_handler = self.handlers.typing_notification_handler
@ -580,7 +581,7 @@ class SearchRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0] batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search( results = yield self.handlers.search_handler.search(
@ -592,17 +593,6 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_txn_path(servlet, regex_string, http_server, with_get=False): def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""Registers a transaction-based path. """Registers a transaction-based path.

View file

@ -17,11 +17,9 @@
""" """
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.errors import SynapseError
import re import re
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,23 +42,3 @@ def client_v2_patterns(path_regex, releases=(0,)):
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex)) patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns return patterns
def parse_request_allow_empty(request):
content = request.content.read()
if content is None or content == '':
return None
try:
return simplejson.loads(content)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")
def parse_json_dict_from_request(request):
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")

View file

@ -17,10 +17,10 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
import logging import logging
@ -41,7 +41,7 @@ class PasswordRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
authed, result, params = yield self.auth_handler.check_auth([ authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
@ -114,7 +114,7 @@ class ThreepidRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
threePidCreds = body.get('threePidCreds') threePidCreds = body.get('threePidCreds')
threePidCreds = body.get('three_pid_creds', threePidCreds) threePidCreds = body.get('three_pid_creds', threePidCreds)

View file

@ -17,9 +17,9 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
import logging import logging
import hmac import hmac
@ -73,7 +73,7 @@ class RegisterRestServlet(RestServlet):
ret = yield self.onEmailTokenRequest(request) ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret) defer.returnValue(ret)
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these # we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us. # in sessions. Pull out the username/password provided to us.
@ -236,7 +236,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def onEmailTokenRequest(self, request): def onEmailTokenRequest(self, request):
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt'] required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = [] absent = []

View file

@ -16,9 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns, parse_json_dict_from_request from ._base import client_v2_patterns
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -35,7 +35,7 @@ class TokenRefreshRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_dict_from_request(request) body = parse_json_object_from_request(request)
try: try:
old_refresh_token = body["refresh_token"] old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler auth_handler = self.hs.get_handlers().auth_handler