0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-21 03:21:54 +01:00

Port rest/ to Python 3 (#3823)

This commit is contained in:
Amber Brown 2018-09-12 20:41:31 +10:00 committed by GitHub
parent 8fd93b5eea
commit 02aa41809b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 113 additions and 100 deletions

1
changelog.d/3823.misc Normal file
View file

@ -0,0 +1 @@
rest/ is now ported to Python 3.

View file

@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet):
nonce = self.hs.get_secrets().token_hex(64) nonce = self.hs.get_secrets().token_hex(64)
self.nonces[nonce] = int(self.reactor.seconds()) self.nonces[nonce] = int(self.reactor.seconds())
return (200, {"nonce": nonce.encode('ascii')}) return (200, {"nonce": nonce})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet):
key=self.hs.config.registration_shared_secret.encode(), key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1, digestmod=hashlib.sha1,
) )
want_mac.update(nonce) want_mac.update(nonce.encode('utf8'))
want_mac.update(b"\x00") want_mac.update(b"\x00")
want_mac.update(username) want_mac.update(username)
want_mac.update(b"\x00") want_mac.update(b"\x00")
@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(b"admin" if admin else b"notadmin") want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest() want_mac = want_mac.hexdigest()
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')): if not hmac.compare_digest(
want_mac.encode('ascii'),
got_mac.encode('ascii')
):
raise SynapseError(403, "HMAC incorrect") raise SynapseError(403, "HMAC incorrect")
# Reuse the parts of RegisterRestServlet to reduce code duplication # Reuse the parts of RegisterRestServlet to reduce code duplication

View file

@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet):
is_guest = requester.is_guest is_guest = requester.is_guest
room_id = None room_id = None
if is_guest: if is_guest:
if "room_id" not in request.args: if b"room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param") raise SynapseError(400, "Guest users must specify room_id param")
if "room_id" in request.args: if b"room_id" in request.args:
room_id = request.args["room_id"][0] room_id = request.args[b"room_id"][0].decode('ascii')
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args: if b"timeout" in request.args:
try: try:
timeout = int(request.args["timeout"][0]) timeout = int(request.args[b"timeout"][0])
except ValueError: except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.") raise SynapseError(400, "timeout must be in milliseconds.")
as_client_event = "raw" not in request.args as_client_event = b"raw" not in request.args
chunk = yield self.event_stream_handler.get_stream( chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(), requester.user.to_string(),

View file

@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False) include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms( content = yield self.initial_sync_handler.snapshot_all_rooms(

View file

@ -14,10 +14,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from six.moves.urllib import parse as urlparse from six.moves import urllib
from canonicaljson import json from canonicaljson import json
from saml2 import BINDING_HTTP_POST, config from saml2 import BINDING_HTTP_POST, config
@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE): LoginRestServlet.SAML2_TYPE):
relay_state = "" relay_state = ""
if "relay_state" in login_submission: if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote( relay_state = "&RelayState=" + urllib.parse.quote(
login_submission["relay_state"]) login_submission["relay_state"])
result = { result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet):
(user_id, token) = yield handler.register_saml2(username) (user_id, token) = yield handler.register_saml2(username)
# Forward to the RelayState callback along with ava # Forward to the RelayState callback along with ava
if 'RelayState' in request.args: if 'RelayState' in request.args:
request.redirect(urllib.unquote( request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) + request.args['RelayState'][0]) +
'?status=authenticated&access_token=' + '?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' + token + '&user_id=' + user_id + '&ava=' +
@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet):
"user_id": user_id, "token": token, "user_id": user_id, "token": token,
"ava": saml2_auth.ava})) "ava": saml2_auth.ava}))
elif 'RelayState' in request.args: elif 'RelayState' in request.args:
request.redirect(urllib.unquote( request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) + request.args['RelayState'][0]) +
'?status=not_authenticated') '?status=not_authenticated')
finish_request(request) finish_request(request)
@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs) super(CasRedirectServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url.encode('ascii')
def on_GET(self, request): def on_GET(self, request):
args = request.args args = request.args
if "redirectUrl" not in args: if b"redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth") return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.urlencode({ client_redirect_url_param = urllib.parse.urlencode({
"redirectUrl": args["redirectUrl"][0] b"redirectUrl": args[b"redirectUrl"][0]
}) }).encode('ascii')
hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket" hs_redirect_url = (self.cas_service_url +
service_param = urllib.urlencode({ b"/_matrix/client/api/v1/login/cas/ticket")
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) service_param = urllib.parse.urlencode({
}) b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
request.redirect("%s/login?%s" % (self.cas_server_url, service_param)) }).encode('ascii')
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request) finish_request(request)
@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
client_redirect_url = request.args["redirectUrl"][0] client_redirect_url = request.args[b"redirectUrl"][0]
http_client = self.hs.get_simple_http_client() http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate" uri = self.cas_server_url + "/proxyValidate"
args = { args = {
"ticket": request.args["ticket"], "ticket": request.args[b"ticket"][0].decode('ascii'),
"service": self.cas_service_url "service": self.cas_service_url
} }
try: try:
@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet):
finish_request(request) finish_request(request)
def add_login_token_to_redirect_url(self, url, token): def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url)) url_parts = list(urllib.parse.urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4])) query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token}) query.update({"loginToken": token})
url_parts[4] = urllib.urlencode(query) url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
return urlparse.urlunparse(url_parts) return urllib.parse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
user = None user = None

View file

@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try: try:
priority_class = _priority_class_from_spec(spec) priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, str(e))
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content, content,
) )
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, str(e))
before = parse_string(request, "before") before = parse_string(request, "before")
if before: if before:
@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
) )
self.notify_user(user_id) self.notify_user(user_id)
except InconsistentRuleException as e: except InconsistentRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, str(e))
except RuleNotFoundException as e: except RuleNotFoundException as e:
raise SynapseError(400, e.message) raise SynapseError(400, str(e))
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
) )
if path[0] == '': if path[0] == b'':
defer.returnValue((200, rules)) defer.returnValue((200, rules))
elif path[0] == 'global': elif path[0] == b'global':
path = path[1:] path = [x.decode('ascii') for x in path[1:]]
result = _filter_ruleset_with_path(rules['global'], path) result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result)) defer.returnValue((200, result))
else: else:
@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
def _rule_spec_from_path(path): def _rule_spec_from_path(path):
if len(path) < 2: if len(path) < 2:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
if path[0] != 'pushrules': if path[0] != b'pushrules':
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
scope = path[1] scope = path[1].decode('ascii')
path = path[2:] path = path[2:]
if scope != 'global': if scope != 'global':
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -203,13 +203,13 @@ def _rule_spec_from_path(path):
if len(path) == 0: if len(path) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
template = path[0] template = path[0].decode('ascii')
path = path[1:] path = path[1:]
if len(path) == 0 or len(path[0]) == 0: if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
rule_id = path[0] rule_id = path[0].decode('ascii')
spec = { spec = {
'scope': scope, 'scope': scope,
@ -220,7 +220,7 @@ def _rule_spec_from_path(path):
path = path[1:] path = path[1:]
if len(path) > 0 and len(path[0]) > 0: if len(path) > 0 and len(path[0]) > 0:
spec['attr'] = path[0] spec['attr'] = path[0].decode('ascii')
return spec return spec

View file

@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet):
] ]
for p in pushers: for p in pushers:
for k, v in p.items(): for k, v in list(p.items()):
if k not in allowed_keys: if k not in allowed_keys:
del p[k] del p[k]
@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
profile_tag=content.get('profile_tag', ""), profile_tag=content.get('profile_tag', ""),
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError(400, "Config Error: " + pce.message, raise SynapseError(400, "Config Error: " + str(pce),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
self.notifier.on_new_replication_data() self.notifier.on_new_replication_data()

View file

@ -207,7 +207,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
} }
if 'ts' in request.args and requester.app_service: if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
event = yield self.event_creation_hander.create_and_send_nonmember_event( event = yield self.event_creation_hander.create_and_send_nonmember_event(
@ -255,7 +255,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
if RoomID.is_valid(room_identifier): if RoomID.is_valid(room_identifier):
room_id = room_identifier room_id = room_identifier
try: try:
remote_room_hosts = request.args["server_name"] remote_room_hosts = [
x.decode('ascii') for x in request.args[b"server_name"]
]
except Exception: except Exception:
remote_room_hosts = None remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
@ -461,10 +463,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
as_client_event = "raw" not in request.args as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, "filter") filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json))
else: else:
event_filter = None event_filter = None
@ -560,7 +562,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
# picking the API shape for symmetry with /messages # picking the API shape for symmetry with /messages
filter_bytes = parse_string(request, "filter") filter_bytes = parse_string(request, "filter")
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") filter_json = urlparse.unquote(filter_bytes)
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json))
else: else:
event_filter = None event_filter = None

View file

@ -89,7 +89,7 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
if "from" in request.args: if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'. # /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'. # Lets be helpful and whine if we see a 'from'.
raise SynapseError( raise SynapseError(

View file

@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True) yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields = request.args
fields.pop("access_token", None) fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe( results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields ThirdPartyEntityKind.USER, protocol, fields
@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True) yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields = request.args
fields.pop("access_token", None) fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe( results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields ThirdPartyEntityKind.LOCATION, protocol, fields

View file

@ -88,5 +88,5 @@ class LocalKey(Resource):
) )
def getChild(self, name, request): def getChild(self, name, request):
if name == '': if name == b'':
return self return self

View file

@ -22,5 +22,5 @@ from .remote_key_resource import RemoteKey
class KeyApiV2Resource(Resource): class KeyApiV2Resource(Resource):
def __init__(self, hs): def __init__(self, hs):
Resource.__init__(self) Resource.__init__(self)
self.putChild("server", LocalKey(hs)) self.putChild(b"server", LocalKey(hs))
self.putChild("query", RemoteKey(hs)) self.putChild(b"query", RemoteKey(hs))

View file

@ -103,7 +103,7 @@ class RemoteKey(Resource):
def async_render_GET(self, request): def async_render_GET(self, request):
if len(request.postpath) == 1: if len(request.postpath) == 1:
server, = request.postpath server, = request.postpath
query = {server: {}} query = {server.decode('ascii'): {}}
elif len(request.postpath) == 2: elif len(request.postpath) == 2:
server, key_id = request.postpath server, key_id = request.postpath
minimum_valid_until_ts = parse_integer( minimum_valid_until_ts = parse_integer(
@ -112,11 +112,12 @@ class RemoteKey(Resource):
arguments = {} arguments = {}
if minimum_valid_until_ts is not None: if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server: {key_id: arguments}} query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
else: else:
raise SynapseError( raise SynapseError(
404, "Not found %r" % request.postpath, Codes.NOT_FOUND 404, "Not found %r" % request.postpath, Codes.NOT_FOUND
) )
yield self.query_keys(request, query, query_remote_on_cache_miss=True) yield self.query_keys(request, query, query_remote_on_cache_miss=True)
def render_POST(self, request): def render_POST(self, request):
@ -135,6 +136,7 @@ class RemoteKey(Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False): def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
store_queries = [] store_queries = []
for server_name, key_ids in query.items(): for server_name, key_ids in query.items():
if ( if (

View file

@ -56,7 +56,7 @@ class ContentRepoResource(resource.Resource):
# servers. # servers.
# TODO: A little crude here, we could do this better. # TODO: A little crude here, we could do this better.
filename = request.path.split('/')[-1] filename = request.path.decode('ascii').split('/')[-1]
# be paranoid # be paranoid
filename = re.sub("[^0-9A-z.-_]", "", filename) filename = re.sub("[^0-9A-z.-_]", "", filename)
@ -78,7 +78,7 @@ class ContentRepoResource(resource.Resource):
# select private. don't bother setting Expires as all our matrix # select private. don't bother setting Expires as all our matrix
# clients are smart enough to be happy with Cache-Control (right?) # clients are smart enough to be happy with Cache-Control (right?)
request.setHeader( request.setHeader(
"Cache-Control", "public,max-age=86400,s-maxage=86400" b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
) )
d = FileSender().beginFileTransfer(f, request) d = FileSender().beginFileTransfer(f, request)

View file

@ -15,9 +15,8 @@
import logging import logging
import os import os
import urllib
from six.moves.urllib import parse as urlparse from six.moves import urllib
from twisted.internet import defer from twisted.internet import defer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
@ -35,10 +34,15 @@ def parse_media_id(request):
# This allows users to append e.g. /test.png to the URL. Useful for # This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type. # clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2] server_name, media_id = request.postpath[:2]
if isinstance(server_name, bytes):
server_name = server_name.decode('utf-8')
media_id = media_id.decode('utf8')
file_name = None file_name = None
if len(request.postpath) > 2: if len(request.postpath) > 2:
try: try:
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8") file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
except UnicodeDecodeError: except UnicodeDecodeError:
pass pass
return server_name, media_id, file_name return server_name, media_id, file_name
@ -93,22 +97,18 @@ def add_file_headers(request, media_type, file_size, upload_name):
file_size (int): Size in bytes of the media, if known. file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any. upload_name (str): The name of the requested file, if any.
""" """
def _quote(x):
return urllib.parse.quote(x.encode("utf-8"))
request.setHeader(b"Content-Type", media_type.encode("UTF-8")) request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name: if upload_name:
if is_ascii(upload_name): if is_ascii(upload_name):
request.setHeader( disposition = ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii")
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else: else:
request.setHeader( disposition = (
b"Content-Disposition", "inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii")
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")), request.setHeader(b"Content-Disposition", disposition)
),
)
# cache for at least a day. # cache for at least a day.
# XXX: we might want to turn this off for data we don't want to # XXX: we might want to turn this off for data we don't want to

View file

@ -47,12 +47,12 @@ class DownloadResource(Resource):
def _async_render_GET(self, request): def _async_render_GET(self, request):
set_cors_headers(request) set_cors_headers(request)
request.setHeader( request.setHeader(
"Content-Security-Policy", b"Content-Security-Policy",
"default-src 'none';" b"default-src 'none';"
" script-src 'none';" b" script-src 'none';"
" plugin-types application/pdf;" b" plugin-types application/pdf;"
" style-src 'unsafe-inline';" b" style-src 'unsafe-inline';"
" object-src 'self';" b" object-src 'self';"
) )
server_name, media_id, name = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:

View file

@ -20,7 +20,7 @@ import logging
import os import os
import shutil import shutil
from six import iteritems from six import PY3, iteritems
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
import twisted.internet.error import twisted.internet.error
@ -397,13 +397,13 @@ class MediaRepository(object):
yield finish() yield finish()
media_type = headers["Content-Type"][0] media_type = headers[b"Content-Type"][0].decode('ascii')
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None) content_disposition = headers.get(b"Content-Disposition", None)
if content_disposition: if content_disposition:
_, params = cgi.parse_header(content_disposition[0],) _, params = cgi.parse_header(content_disposition[0].decode('ascii'),)
upload_name = None upload_name = None
# First check if there is a valid UTF-8 filename # First check if there is a valid UTF-8 filename
@ -419,9 +419,13 @@ class MediaRepository(object):
upload_name = upload_name_ascii upload_name = upload_name_ascii
if upload_name: if upload_name:
upload_name = urlparse.unquote(upload_name) if PY3:
upload_name = urlparse.unquote(upload_name)
else:
upload_name = urlparse.unquote(upload_name.encode('ascii'))
try: try:
upload_name = upload_name.decode("utf-8") if isinstance(upload_name, bytes):
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError: except UnicodeDecodeError:
upload_name = None upload_name = None
else: else:
@ -755,14 +759,15 @@ class MediaRepositoryResource(Resource):
Resource.__init__(self) Resource.__init__(self)
media_repo = hs.get_media_repository() media_repo = hs.get_media_repository()
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo)) self.putChild(b"upload", UploadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource( self.putChild(b"download", DownloadResource(hs, media_repo))
self.putChild(b"thumbnail", ThumbnailResource(
hs, media_repo, media_repo.media_storage, hs, media_repo, media_repo.media_storage,
)) ))
self.putChild("identicon", IdenticonResource()) self.putChild(b"identicon", IdenticonResource())
if hs.config.url_preview_enabled: if hs.config.url_preview_enabled:
self.putChild("preview_url", PreviewUrlResource( self.putChild(b"preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage, hs, media_repo, media_repo.media_storage,
)) ))
self.putChild("config", MediaConfigResource(hs)) self.putChild(b"config", MediaConfigResource(hs))

View file

@ -261,7 +261,7 @@ class PreviewUrlResource(Resource):
logger.debug("Calculated OG for %s as %s" % (url, og)) logger.debug("Calculated OG for %s as %s" % (url, og))
jsonog = json.dumps(og) jsonog = json.dumps(og).encode('utf8')
# store OG in history-aware DB cache # store OG in history-aware DB cache
yield self.store.store_url_cache( yield self.store.store_url_cache(
@ -301,20 +301,20 @@ class PreviewUrlResource(Resource):
logger.warn("Error downloading %s: %r", url, e) logger.warn("Error downloading %s: %r", url, e)
raise SynapseError( raise SynapseError(
500, "Failed to download content: %s" % ( 500, "Failed to download content: %s" % (
traceback.format_exception_only(sys.exc_type, e), traceback.format_exception_only(sys.exc_info()[0], e),
), ),
Codes.UNKNOWN, Codes.UNKNOWN,
) )
yield finish() yield finish()
try: try:
if "Content-Type" in headers: if b"Content-Type" in headers:
media_type = headers["Content-Type"][0] media_type = headers[b"Content-Type"][0].decode('ascii')
else: else:
media_type = "application/octet-stream" media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None) content_disposition = headers.get(b"Content-Disposition", None)
if content_disposition: if content_disposition:
_, params = cgi.parse_header(content_disposition[0],) _, params = cgi.parse_header(content_disposition[0],)
download_name = None download_name = None