mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 05:13:50 +01:00
Combine the request wrappers in rest/media/v1 and http/server into a single wrapper decorator
This commit is contained in:
parent
1ebff9736b
commit
1967650bc4
7 changed files with 140 additions and 198 deletions
|
@ -57,9 +57,18 @@ _next_request_id = 0
|
|||
def request_handler(request_handler):
|
||||
"""Wraps a method that acts as a request handler with the necessary logging
|
||||
and exception handling.
|
||||
The method must have a signature of "handle_foo(self, request)".
|
||||
The argument "self" must have "version_string" and "clock" attributes.
|
||||
The argument "request" must be a twisted HTTP request.
|
||||
|
||||
The method must have a signature of "handle_foo(self, request)". The
|
||||
argument "self" must have "version_string" and "clock" attributes. The
|
||||
argument "request" must be a twisted HTTP request.
|
||||
|
||||
The method must return a deferred. If the deferred succeeds we assume that
|
||||
a response has been sent. If the deferred fails with a SynapseError we use
|
||||
it to send a JSON response with the appropriate HTTP reponse code. If the
|
||||
deferred fails with any other type of error we send a 500 reponse.
|
||||
|
||||
We insert a unique request-id into the logging context for this request and
|
||||
log the response and duration for this request.
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -23,6 +23,61 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_integer(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return int(request.args[name][0])
|
||||
except:
|
||||
message = "Query parameter %r must be an integer" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing integer query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def parse_boolean(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return {
|
||||
"true": True,
|
||||
"false": False,
|
||||
}[request.args[name][0]]
|
||||
except:
|
||||
message = (
|
||||
"Boolean query parameter %r must be one of"
|
||||
" ['true', 'false']"
|
||||
) % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing boolean query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def parse_string(request, name, default=None, required=False,
|
||||
allowed_values=None, param_type="string"):
|
||||
if name in request.args:
|
||||
value = request.args[name][0]
|
||||
if allowed_values is not None and value not in allowed_values:
|
||||
message = "Query parameter %r must be one of [%s]" % (
|
||||
name, ", ".join(repr(v) for v in allowed_values)
|
||||
)
|
||||
raise SynapseError(message)
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
if required:
|
||||
message = "Missing %s query parameter %r" % (param_type, name)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
class RestServlet(object):
|
||||
|
||||
""" A Synapse REST Servlet.
|
||||
|
@ -56,58 +111,3 @@ class RestServlet(object):
|
|||
http_server.register_path(method, pattern, method_handler)
|
||||
else:
|
||||
raise NotImplementedError("RestServlet must register something.")
|
||||
|
||||
@staticmethod
|
||||
def parse_integer(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return int(request.args[name][0])
|
||||
except:
|
||||
message = "Query parameter %r must be an integer" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing integer query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def parse_boolean(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return {
|
||||
"true": True,
|
||||
"false": False,
|
||||
}[request.args[name][0]]
|
||||
except:
|
||||
message = (
|
||||
"Boolean query parameter %r must be one of"
|
||||
" ['true', 'false']"
|
||||
) % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing boolean query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def parse_string(request, name, default=None, required=False,
|
||||
allowed_values=None, param_type="string"):
|
||||
if name in request.args:
|
||||
value = request.args[name][0]
|
||||
if allowed_values is not None and value not in allowed_values:
|
||||
message = "Query parameter %r must be one of [%s]" % (
|
||||
name, ", ".join(repr(v) for v in allowed_values)
|
||||
)
|
||||
raise SynapseError(message)
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
if required:
|
||||
message = "Missing %s query parameter %r" % (param_type, name)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_string, parse_integer, parse_boolean
|
||||
)
|
||||
from synapse.handlers.sync import SyncConfig
|
||||
from synapse.types import StreamToken
|
||||
from synapse.events.utils import (
|
||||
|
@ -87,20 +89,20 @@ class SyncRestServlet(RestServlet):
|
|||
def on_GET(self, request):
|
||||
user, client = yield self.auth.get_user_by_req(request)
|
||||
|
||||
timeout = self.parse_integer(request, "timeout", default=0)
|
||||
limit = self.parse_integer(request, "limit", required=True)
|
||||
gap = self.parse_boolean(request, "gap", default=True)
|
||||
sort = self.parse_string(
|
||||
timeout = parse_integer(request, "timeout", default=0)
|
||||
limit = parse_integer(request, "limit", required=True)
|
||||
gap = parse_boolean(request, "gap", default=True)
|
||||
sort = parse_string(
|
||||
request, "sort", default="timeline,asc",
|
||||
allowed_values=self.ALLOWED_SORT
|
||||
)
|
||||
since = self.parse_string(request, "since")
|
||||
set_presence = self.parse_string(
|
||||
since = parse_string(request, "since")
|
||||
set_presence = parse_string(
|
||||
request, "set_presence", default="online",
|
||||
allowed_values=self.ALLOWED_PRESENCE
|
||||
)
|
||||
backfill = self.parse_boolean(request, "backfill", default=False)
|
||||
filter_id = self.parse_string(request, "filter", default=None)
|
||||
backfill = parse_boolean(request, "backfill", default=False)
|
||||
filter_id = parse_string(request, "filter", default=None)
|
||||
|
||||
logger.info(
|
||||
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"
|
||||
|
|
|
@ -18,7 +18,7 @@ from .thumbnailer import Thumbnailer
|
|||
from synapse.http.server import respond_with_json
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, CodeMessageException, cs_error, Codes, SynapseError
|
||||
cs_error, Codes, SynapseError
|
||||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -32,6 +32,18 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_media_id(request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
return (server_name, media_id)
|
||||
except:
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Invalid media id token %r" % (request.postpath,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
class BaseMediaResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
|
@ -47,72 +59,6 @@ class BaseMediaResource(Resource):
|
|||
self.filepaths = filepaths
|
||||
self.downloads = {}
|
||||
|
||||
@staticmethod
|
||||
def catch_errors(request_handler):
|
||||
@defer.inlineCallbacks
|
||||
def wrapped_request_handler(self, request):
|
||||
try:
|
||||
yield request_handler(self, request)
|
||||
except CodeMessageException as e:
|
||||
logger.info("Responding with error: %r", e)
|
||||
respond_with_json(
|
||||
request, e.code, cs_exception(e), send_cors=True
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed handle request %s.%s on %r",
|
||||
request_handler.__module__,
|
||||
request_handler.__name__,
|
||||
self,
|
||||
)
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
)
|
||||
return wrapped_request_handler
|
||||
|
||||
@staticmethod
|
||||
def _parse_media_id(request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
return (server_name, media_id)
|
||||
except:
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Invalid media id token %r" % (request.postpath,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_integer(request, arg_name, default=None):
|
||||
try:
|
||||
if default is None:
|
||||
return int(request.args[arg_name][0])
|
||||
else:
|
||||
return int(request.args.get(arg_name, [default])[0])
|
||||
except:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing integer argument %r" % (arg_name,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_string(request, arg_name, default=None):
|
||||
try:
|
||||
if default is None:
|
||||
return request.args[arg_name][0]
|
||||
else:
|
||||
return request.args.get(arg_name, [default])[0]
|
||||
except:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing string argument %r" % (arg_name,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
def _respond_404(self, request):
|
||||
respond_with_json(
|
||||
request, 404,
|
||||
|
|
|
@ -13,7 +13,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_resource import BaseMediaResource
|
||||
from .base_resource import BaseMediaResource, parse_media_id
|
||||
from synapse.http.server import request_handler
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
@ -28,15 +29,10 @@ class DownloadResource(BaseMediaResource):
|
|||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@BaseMediaResource.catch_errors
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
except:
|
||||
self._respond_404(request)
|
||||
return
|
||||
|
||||
server_name, media_id = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_file(request, media_id)
|
||||
else:
|
||||
|
|
|
@ -14,7 +14,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
from .base_resource import BaseMediaResource
|
||||
from .base_resource import BaseMediaResource, parse_media_id
|
||||
from synapse.http.servlet import parse_string, parse_integer
|
||||
from synapse.http.server import request_handler
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
@ -31,14 +33,14 @@ class ThumbnailResource(BaseMediaResource):
|
|||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@BaseMediaResource.catch_errors
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
server_name, media_id = self._parse_media_id(request)
|
||||
width = self._parse_integer(request, "width")
|
||||
height = self._parse_integer(request, "height")
|
||||
method = self._parse_string(request, "method", "scale")
|
||||
m_type = self._parse_string(request, "type", "image/png")
|
||||
server_name, media_id = parse_media_id(request)
|
||||
width = parse_integer(request, "width")
|
||||
height = parse_integer(request, "height")
|
||||
method = parse_string(request, "method", "scale")
|
||||
m_type = parse_string(request, "type", "image/png")
|
||||
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_thumbnail(
|
||||
|
|
|
@ -13,12 +13,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import respond_with_json
|
||||
from synapse.http.server import respond_with_json, request_handler
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, SynapseError, CodeMessageException
|
||||
)
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
@ -69,9 +67,9 @@ class UploadResource(BaseMediaResource):
|
|||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
try:
|
||||
auth_user, client = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
|
@ -108,14 +106,3 @@ class UploadResource(BaseMediaResource):
|
|||
respond_with_json(
|
||||
request, 200, {"content_uri": content_uri}, send_cors=True
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
respond_with_json(request, e.code, cs_exception(e), send_cors=True)
|
||||
except:
|
||||
logger.exception("Failed to store file")
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue