mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-07 16:03:52 +01:00
Port rest/ to Python 3 (#3823)
This commit is contained in:
parent
8fd93b5eea
commit
02aa41809b
18 changed files with 113 additions and 100 deletions
1
changelog.d/3823.misc
Normal file
1
changelog.d/3823.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
rest/ is now ported to Python 3.
|
|
@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
nonce = self.hs.get_secrets().token_hex(64)
|
nonce = self.hs.get_secrets().token_hex(64)
|
||||||
self.nonces[nonce] = int(self.reactor.seconds())
|
self.nonces[nonce] = int(self.reactor.seconds())
|
||||||
return (200, {"nonce": nonce.encode('ascii')})
|
return (200, {"nonce": nonce})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
key=self.hs.config.registration_shared_secret.encode(),
|
key=self.hs.config.registration_shared_secret.encode(),
|
||||||
digestmod=hashlib.sha1,
|
digestmod=hashlib.sha1,
|
||||||
)
|
)
|
||||||
want_mac.update(nonce)
|
want_mac.update(nonce.encode('utf8'))
|
||||||
want_mac.update(b"\x00")
|
want_mac.update(b"\x00")
|
||||||
want_mac.update(username)
|
want_mac.update(username)
|
||||||
want_mac.update(b"\x00")
|
want_mac.update(b"\x00")
|
||||||
|
@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
want_mac.update(b"admin" if admin else b"notadmin")
|
want_mac.update(b"admin" if admin else b"notadmin")
|
||||||
want_mac = want_mac.hexdigest()
|
want_mac = want_mac.hexdigest()
|
||||||
|
|
||||||
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
|
if not hmac.compare_digest(
|
||||||
|
want_mac.encode('ascii'),
|
||||||
|
got_mac.encode('ascii')
|
||||||
|
):
|
||||||
raise SynapseError(403, "HMAC incorrect")
|
raise SynapseError(403, "HMAC incorrect")
|
||||||
|
|
||||||
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
||||||
|
|
|
@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
||||||
is_guest = requester.is_guest
|
is_guest = requester.is_guest
|
||||||
room_id = None
|
room_id = None
|
||||||
if is_guest:
|
if is_guest:
|
||||||
if "room_id" not in request.args:
|
if b"room_id" not in request.args:
|
||||||
raise SynapseError(400, "Guest users must specify room_id param")
|
raise SynapseError(400, "Guest users must specify room_id param")
|
||||||
if "room_id" in request.args:
|
if b"room_id" in request.args:
|
||||||
room_id = request.args["room_id"][0]
|
room_id = request.args[b"room_id"][0].decode('ascii')
|
||||||
|
|
||||||
pagin_config = PaginationConfig.from_request(request)
|
pagin_config = PaginationConfig.from_request(request)
|
||||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||||
if "timeout" in request.args:
|
if b"timeout" in request.args:
|
||||||
try:
|
try:
|
||||||
timeout = int(request.args["timeout"][0])
|
timeout = int(request.args[b"timeout"][0])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||||
|
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = b"raw" not in request.args
|
||||||
|
|
||||||
chunk = yield self.event_stream_handler.get_stream(
|
chunk = yield self.event_stream_handler.get_stream(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
|
|
|
@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = b"raw" not in request.args
|
||||||
pagination_config = PaginationConfig.from_request(request)
|
pagination_config = PaginationConfig.from_request(request)
|
||||||
include_archived = parse_boolean(request, "archived", default=False)
|
include_archived = parse_boolean(request, "archived", default=False)
|
||||||
content = yield self.initial_sync_handler.snapshot_all_rooms(
|
content = yield self.initial_sync_handler.snapshot_all_rooms(
|
||||||
|
|
|
@ -14,10 +14,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
from six.moves.urllib import parse as urlparse
|
from six.moves import urllib
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
from saml2 import BINDING_HTTP_POST, config
|
from saml2 import BINDING_HTTP_POST, config
|
||||||
|
@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
LoginRestServlet.SAML2_TYPE):
|
LoginRestServlet.SAML2_TYPE):
|
||||||
relay_state = ""
|
relay_state = ""
|
||||||
if "relay_state" in login_submission:
|
if "relay_state" in login_submission:
|
||||||
relay_state = "&RelayState=" + urllib.quote(
|
relay_state = "&RelayState=" + urllib.parse.quote(
|
||||||
login_submission["relay_state"])
|
login_submission["relay_state"])
|
||||||
result = {
|
result = {
|
||||||
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
|
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
|
||||||
|
@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet):
|
||||||
(user_id, token) = yield handler.register_saml2(username)
|
(user_id, token) = yield handler.register_saml2(username)
|
||||||
# Forward to the RelayState callback along with ava
|
# Forward to the RelayState callback along with ava
|
||||||
if 'RelayState' in request.args:
|
if 'RelayState' in request.args:
|
||||||
request.redirect(urllib.unquote(
|
request.redirect(urllib.parse.unquote(
|
||||||
request.args['RelayState'][0]) +
|
request.args['RelayState'][0]) +
|
||||||
'?status=authenticated&access_token=' +
|
'?status=authenticated&access_token=' +
|
||||||
token + '&user_id=' + user_id + '&ava=' +
|
token + '&user_id=' + user_id + '&ava=' +
|
||||||
|
@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet):
|
||||||
"user_id": user_id, "token": token,
|
"user_id": user_id, "token": token,
|
||||||
"ava": saml2_auth.ava}))
|
"ava": saml2_auth.ava}))
|
||||||
elif 'RelayState' in request.args:
|
elif 'RelayState' in request.args:
|
||||||
request.redirect(urllib.unquote(
|
request.redirect(urllib.parse.unquote(
|
||||||
request.args['RelayState'][0]) +
|
request.args['RelayState'][0]) +
|
||||||
'?status=not_authenticated')
|
'?status=not_authenticated')
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(CasRedirectServlet, self).__init__(hs)
|
super(CasRedirectServlet, self).__init__(hs)
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
|
||||||
self.cas_service_url = hs.config.cas_service_url
|
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
args = request.args
|
args = request.args
|
||||||
if "redirectUrl" not in args:
|
if b"redirectUrl" not in args:
|
||||||
return (400, "Redirect URL not specified for CAS auth")
|
return (400, "Redirect URL not specified for CAS auth")
|
||||||
client_redirect_url_param = urllib.urlencode({
|
client_redirect_url_param = urllib.parse.urlencode({
|
||||||
"redirectUrl": args["redirectUrl"][0]
|
b"redirectUrl": args[b"redirectUrl"][0]
|
||||||
})
|
}).encode('ascii')
|
||||||
hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
|
hs_redirect_url = (self.cas_service_url +
|
||||||
service_param = urllib.urlencode({
|
b"/_matrix/client/api/v1/login/cas/ticket")
|
||||||
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
service_param = urllib.parse.urlencode({
|
||||||
})
|
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
||||||
request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
|
}).encode('ascii')
|
||||||
|
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
client_redirect_url = request.args["redirectUrl"][0]
|
client_redirect_url = request.args[b"redirectUrl"][0]
|
||||||
http_client = self.hs.get_simple_http_client()
|
http_client = self.hs.get_simple_http_client()
|
||||||
uri = self.cas_server_url + "/proxyValidate"
|
uri = self.cas_server_url + "/proxyValidate"
|
||||||
args = {
|
args = {
|
||||||
"ticket": request.args["ticket"],
|
"ticket": request.args[b"ticket"][0].decode('ascii'),
|
||||||
"service": self.cas_service_url
|
"service": self.cas_service_url
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
|
@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
|
||||||
def add_login_token_to_redirect_url(self, url, token):
|
def add_login_token_to_redirect_url(self, url, token):
|
||||||
url_parts = list(urlparse.urlparse(url))
|
url_parts = list(urllib.parse.urlparse(url))
|
||||||
query = dict(urlparse.parse_qsl(url_parts[4]))
|
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
||||||
query.update({"loginToken": token})
|
query.update({"loginToken": token})
|
||||||
url_parts[4] = urllib.urlencode(query)
|
url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
|
||||||
return urlparse.urlunparse(url_parts)
|
return urllib.parse.urlunparse(url_parts)
|
||||||
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
def parse_cas_response(self, cas_response_body):
|
||||||
user = None
|
user = None
|
||||||
|
|
|
@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
try:
|
try:
|
||||||
priority_class = _priority_class_from_spec(spec)
|
priority_class = _priority_class_from_spec(spec)
|
||||||
except InvalidRuleException as e:
|
except InvalidRuleException as e:
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, str(e))
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
content,
|
content,
|
||||||
)
|
)
|
||||||
except InvalidRuleException as e:
|
except InvalidRuleException as e:
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, str(e))
|
||||||
|
|
||||||
before = parse_string(request, "before")
|
before = parse_string(request, "before")
|
||||||
if before:
|
if before:
|
||||||
|
@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
self.notify_user(user_id)
|
self.notify_user(user_id)
|
||||||
except InconsistentRuleException as e:
|
except InconsistentRuleException as e:
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, str(e))
|
||||||
except RuleNotFoundException as e:
|
except RuleNotFoundException as e:
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, str(e))
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
if path[0] == '':
|
if path[0] == b'':
|
||||||
defer.returnValue((200, rules))
|
defer.returnValue((200, rules))
|
||||||
elif path[0] == 'global':
|
elif path[0] == b'global':
|
||||||
path = path[1:]
|
path = [x.decode('ascii') for x in path[1:]]
|
||||||
result = _filter_ruleset_with_path(rules['global'], path)
|
result = _filter_ruleset_with_path(rules['global'], path)
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
else:
|
else:
|
||||||
|
@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
def _rule_spec_from_path(path):
|
def _rule_spec_from_path(path):
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
raise UnrecognizedRequestError()
|
raise UnrecognizedRequestError()
|
||||||
if path[0] != 'pushrules':
|
if path[0] != b'pushrules':
|
||||||
raise UnrecognizedRequestError()
|
raise UnrecognizedRequestError()
|
||||||
|
|
||||||
scope = path[1]
|
scope = path[1].decode('ascii')
|
||||||
path = path[2:]
|
path = path[2:]
|
||||||
if scope != 'global':
|
if scope != 'global':
|
||||||
raise UnrecognizedRequestError()
|
raise UnrecognizedRequestError()
|
||||||
|
@ -203,13 +203,13 @@ def _rule_spec_from_path(path):
|
||||||
if len(path) == 0:
|
if len(path) == 0:
|
||||||
raise UnrecognizedRequestError()
|
raise UnrecognizedRequestError()
|
||||||
|
|
||||||
template = path[0]
|
template = path[0].decode('ascii')
|
||||||
path = path[1:]
|
path = path[1:]
|
||||||
|
|
||||||
if len(path) == 0 or len(path[0]) == 0:
|
if len(path) == 0 or len(path[0]) == 0:
|
||||||
raise UnrecognizedRequestError()
|
raise UnrecognizedRequestError()
|
||||||
|
|
||||||
rule_id = path[0]
|
rule_id = path[0].decode('ascii')
|
||||||
|
|
||||||
spec = {
|
spec = {
|
||||||
'scope': scope,
|
'scope': scope,
|
||||||
|
@ -220,7 +220,7 @@ def _rule_spec_from_path(path):
|
||||||
path = path[1:]
|
path = path[1:]
|
||||||
|
|
||||||
if len(path) > 0 and len(path[0]) > 0:
|
if len(path) > 0 and len(path[0]) > 0:
|
||||||
spec['attr'] = path[0]
|
spec['attr'] = path[0].decode('ascii')
|
||||||
|
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet):
|
||||||
]
|
]
|
||||||
|
|
||||||
for p in pushers:
|
for p in pushers:
|
||||||
for k, v in p.items():
|
for k, v in list(p.items()):
|
||||||
if k not in allowed_keys:
|
if k not in allowed_keys:
|
||||||
del p[k]
|
del p[k]
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
|
||||||
profile_tag=content.get('profile_tag', ""),
|
profile_tag=content.get('profile_tag', ""),
|
||||||
)
|
)
|
||||||
except PusherConfigException as pce:
|
except PusherConfigException as pce:
|
||||||
raise SynapseError(400, "Config Error: " + pce.message,
|
raise SynapseError(400, "Config Error: " + str(pce),
|
||||||
errcode=Codes.MISSING_PARAM)
|
errcode=Codes.MISSING_PARAM)
|
||||||
|
|
||||||
self.notifier.on_new_replication_data()
|
self.notifier.on_new_replication_data()
|
||||||
|
|
|
@ -207,7 +207,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||||
"sender": requester.user.to_string(),
|
"sender": requester.user.to_string(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if 'ts' in request.args and requester.app_service:
|
if b'ts' in request.args and requester.app_service:
|
||||||
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||||
|
|
||||||
event = yield self.event_creation_hander.create_and_send_nonmember_event(
|
event = yield self.event_creation_hander.create_and_send_nonmember_event(
|
||||||
|
@ -255,7 +255,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
if RoomID.is_valid(room_identifier):
|
if RoomID.is_valid(room_identifier):
|
||||||
room_id = room_identifier
|
room_id = room_identifier
|
||||||
try:
|
try:
|
||||||
remote_room_hosts = request.args["server_name"]
|
remote_room_hosts = [
|
||||||
|
x.decode('ascii') for x in request.args[b"server_name"]
|
||||||
|
]
|
||||||
except Exception:
|
except Exception:
|
||||||
remote_room_hosts = None
|
remote_room_hosts = None
|
||||||
elif RoomAlias.is_valid(room_identifier):
|
elif RoomAlias.is_valid(room_identifier):
|
||||||
|
@ -461,10 +463,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
pagination_config = PaginationConfig.from_request(
|
pagination_config = PaginationConfig.from_request(
|
||||||
request, default_limit=10,
|
request, default_limit=10,
|
||||||
)
|
)
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = b"raw" not in request.args
|
||||||
filter_bytes = parse_string(request, "filter")
|
filter_bytes = parse_string(request, b"filter", encoding=None)
|
||||||
if filter_bytes:
|
if filter_bytes:
|
||||||
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
|
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
|
||||||
event_filter = Filter(json.loads(filter_json))
|
event_filter = Filter(json.loads(filter_json))
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
@ -560,7 +562,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
|
||||||
# picking the API shape for symmetry with /messages
|
# picking the API shape for symmetry with /messages
|
||||||
filter_bytes = parse_string(request, "filter")
|
filter_bytes = parse_string(request, "filter")
|
||||||
if filter_bytes:
|
if filter_bytes:
|
||||||
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
|
filter_json = urlparse.unquote(filter_bytes)
|
||||||
event_filter = Filter(json.loads(filter_json))
|
event_filter = Filter(json.loads(filter_json))
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
|
|
@ -89,7 +89,7 @@ class SyncRestServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
if "from" in request.args:
|
if b"from" in request.args:
|
||||||
# /events used to use 'from', but /sync uses 'since'.
|
# /events used to use 'from', but /sync uses 'since'.
|
||||||
# Lets be helpful and whine if we see a 'from'.
|
# Lets be helpful and whine if we see a 'from'.
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
|
|
@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet):
|
||||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
fields = request.args
|
fields = request.args
|
||||||
fields.pop("access_token", None)
|
fields.pop(b"access_token", None)
|
||||||
|
|
||||||
results = yield self.appservice_handler.query_3pe(
|
results = yield self.appservice_handler.query_3pe(
|
||||||
ThirdPartyEntityKind.USER, protocol, fields
|
ThirdPartyEntityKind.USER, protocol, fields
|
||||||
|
@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet):
|
||||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
fields = request.args
|
fields = request.args
|
||||||
fields.pop("access_token", None)
|
fields.pop(b"access_token", None)
|
||||||
|
|
||||||
results = yield self.appservice_handler.query_3pe(
|
results = yield self.appservice_handler.query_3pe(
|
||||||
ThirdPartyEntityKind.LOCATION, protocol, fields
|
ThirdPartyEntityKind.LOCATION, protocol, fields
|
||||||
|
|
|
@ -88,5 +88,5 @@ class LocalKey(Resource):
|
||||||
)
|
)
|
||||||
|
|
||||||
def getChild(self, name, request):
|
def getChild(self, name, request):
|
||||||
if name == '':
|
if name == b'':
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -22,5 +22,5 @@ from .remote_key_resource import RemoteKey
|
||||||
class KeyApiV2Resource(Resource):
|
class KeyApiV2Resource(Resource):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
self.putChild("server", LocalKey(hs))
|
self.putChild(b"server", LocalKey(hs))
|
||||||
self.putChild("query", RemoteKey(hs))
|
self.putChild(b"query", RemoteKey(hs))
|
||||||
|
|
|
@ -103,7 +103,7 @@ class RemoteKey(Resource):
|
||||||
def async_render_GET(self, request):
|
def async_render_GET(self, request):
|
||||||
if len(request.postpath) == 1:
|
if len(request.postpath) == 1:
|
||||||
server, = request.postpath
|
server, = request.postpath
|
||||||
query = {server: {}}
|
query = {server.decode('ascii'): {}}
|
||||||
elif len(request.postpath) == 2:
|
elif len(request.postpath) == 2:
|
||||||
server, key_id = request.postpath
|
server, key_id = request.postpath
|
||||||
minimum_valid_until_ts = parse_integer(
|
minimum_valid_until_ts = parse_integer(
|
||||||
|
@ -112,11 +112,12 @@ class RemoteKey(Resource):
|
||||||
arguments = {}
|
arguments = {}
|
||||||
if minimum_valid_until_ts is not None:
|
if minimum_valid_until_ts is not None:
|
||||||
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
|
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
|
||||||
query = {server: {key_id: arguments}}
|
query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
|
||||||
else:
|
else:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
|
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||||
|
|
||||||
def render_POST(self, request):
|
def render_POST(self, request):
|
||||||
|
@ -135,6 +136,7 @@ class RemoteKey(Resource):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
||||||
logger.info("Handling query for keys %r", query)
|
logger.info("Handling query for keys %r", query)
|
||||||
|
|
||||||
store_queries = []
|
store_queries = []
|
||||||
for server_name, key_ids in query.items():
|
for server_name, key_ids in query.items():
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -56,7 +56,7 @@ class ContentRepoResource(resource.Resource):
|
||||||
# servers.
|
# servers.
|
||||||
|
|
||||||
# TODO: A little crude here, we could do this better.
|
# TODO: A little crude here, we could do this better.
|
||||||
filename = request.path.split('/')[-1]
|
filename = request.path.decode('ascii').split('/')[-1]
|
||||||
# be paranoid
|
# be paranoid
|
||||||
filename = re.sub("[^0-9A-z.-_]", "", filename)
|
filename = re.sub("[^0-9A-z.-_]", "", filename)
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ class ContentRepoResource(resource.Resource):
|
||||||
# select private. don't bother setting Expires as all our matrix
|
# select private. don't bother setting Expires as all our matrix
|
||||||
# clients are smart enough to be happy with Cache-Control (right?)
|
# clients are smart enough to be happy with Cache-Control (right?)
|
||||||
request.setHeader(
|
request.setHeader(
|
||||||
"Cache-Control", "public,max-age=86400,s-maxage=86400"
|
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
||||||
)
|
)
|
||||||
|
|
||||||
d = FileSender().beginFileTransfer(f, request)
|
d = FileSender().beginFileTransfer(f, request)
|
||||||
|
|
|
@ -15,9 +15,8 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import urllib
|
|
||||||
|
|
||||||
from six.moves.urllib import parse as urlparse
|
from six.moves import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
|
@ -35,10 +34,15 @@ def parse_media_id(request):
|
||||||
# This allows users to append e.g. /test.png to the URL. Useful for
|
# This allows users to append e.g. /test.png to the URL. Useful for
|
||||||
# clients that parse the URL to see content type.
|
# clients that parse the URL to see content type.
|
||||||
server_name, media_id = request.postpath[:2]
|
server_name, media_id = request.postpath[:2]
|
||||||
|
|
||||||
|
if isinstance(server_name, bytes):
|
||||||
|
server_name = server_name.decode('utf-8')
|
||||||
|
media_id = media_id.decode('utf8')
|
||||||
|
|
||||||
file_name = None
|
file_name = None
|
||||||
if len(request.postpath) > 2:
|
if len(request.postpath) > 2:
|
||||||
try:
|
try:
|
||||||
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
|
file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
pass
|
pass
|
||||||
return server_name, media_id, file_name
|
return server_name, media_id, file_name
|
||||||
|
@ -93,22 +97,18 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
||||||
file_size (int): Size in bytes of the media, if known.
|
file_size (int): Size in bytes of the media, if known.
|
||||||
upload_name (str): The name of the requested file, if any.
|
upload_name (str): The name of the requested file, if any.
|
||||||
"""
|
"""
|
||||||
|
def _quote(x):
|
||||||
|
return urllib.parse.quote(x.encode("utf-8"))
|
||||||
|
|
||||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||||
if upload_name:
|
if upload_name:
|
||||||
if is_ascii(upload_name):
|
if is_ascii(upload_name):
|
||||||
request.setHeader(
|
disposition = ("inline; filename=%s" % (_quote(upload_name),)).encode("ascii")
|
||||||
b"Content-Disposition",
|
|
||||||
b"inline; filename=%s" % (
|
|
||||||
urllib.quote(upload_name.encode("utf-8")),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
request.setHeader(
|
disposition = (
|
||||||
b"Content-Disposition",
|
"inline; filename*=utf-8''%s" % (_quote(upload_name),)).encode("ascii")
|
||||||
b"inline; filename*=utf-8''%s" % (
|
|
||||||
urllib.quote(upload_name.encode("utf-8")),
|
request.setHeader(b"Content-Disposition", disposition)
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# cache for at least a day.
|
# cache for at least a day.
|
||||||
# XXX: we might want to turn this off for data we don't want to
|
# XXX: we might want to turn this off for data we don't want to
|
||||||
|
|
|
@ -47,12 +47,12 @@ class DownloadResource(Resource):
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
request.setHeader(
|
request.setHeader(
|
||||||
"Content-Security-Policy",
|
b"Content-Security-Policy",
|
||||||
"default-src 'none';"
|
b"default-src 'none';"
|
||||||
" script-src 'none';"
|
b" script-src 'none';"
|
||||||
" plugin-types application/pdf;"
|
b" plugin-types application/pdf;"
|
||||||
" style-src 'unsafe-inline';"
|
b" style-src 'unsafe-inline';"
|
||||||
" object-src 'self';"
|
b" object-src 'self';"
|
||||||
)
|
)
|
||||||
server_name, media_id, name = parse_media_id(request)
|
server_name, media_id, name = parse_media_id(request)
|
||||||
if server_name == self.server_name:
|
if server_name == self.server_name:
|
||||||
|
|
|
@ -20,7 +20,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from six import iteritems
|
from six import PY3, iteritems
|
||||||
from six.moves.urllib import parse as urlparse
|
from six.moves.urllib import parse as urlparse
|
||||||
|
|
||||||
import twisted.internet.error
|
import twisted.internet.error
|
||||||
|
@ -397,13 +397,13 @@ class MediaRepository(object):
|
||||||
|
|
||||||
yield finish()
|
yield finish()
|
||||||
|
|
||||||
media_type = headers["Content-Type"][0]
|
media_type = headers[b"Content-Type"][0].decode('ascii')
|
||||||
|
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
content_disposition = headers.get("Content-Disposition", None)
|
content_disposition = headers.get(b"Content-Disposition", None)
|
||||||
if content_disposition:
|
if content_disposition:
|
||||||
_, params = cgi.parse_header(content_disposition[0],)
|
_, params = cgi.parse_header(content_disposition[0].decode('ascii'),)
|
||||||
upload_name = None
|
upload_name = None
|
||||||
|
|
||||||
# First check if there is a valid UTF-8 filename
|
# First check if there is a valid UTF-8 filename
|
||||||
|
@ -419,9 +419,13 @@ class MediaRepository(object):
|
||||||
upload_name = upload_name_ascii
|
upload_name = upload_name_ascii
|
||||||
|
|
||||||
if upload_name:
|
if upload_name:
|
||||||
upload_name = urlparse.unquote(upload_name)
|
if PY3:
|
||||||
|
upload_name = urlparse.unquote(upload_name)
|
||||||
|
else:
|
||||||
|
upload_name = urlparse.unquote(upload_name.encode('ascii'))
|
||||||
try:
|
try:
|
||||||
upload_name = upload_name.decode("utf-8")
|
if isinstance(upload_name, bytes):
|
||||||
|
upload_name = upload_name.decode("utf-8")
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
upload_name = None
|
upload_name = None
|
||||||
else:
|
else:
|
||||||
|
@ -755,14 +759,15 @@ class MediaRepositoryResource(Resource):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
|
|
||||||
media_repo = hs.get_media_repository()
|
media_repo = hs.get_media_repository()
|
||||||
self.putChild("upload", UploadResource(hs, media_repo))
|
|
||||||
self.putChild("download", DownloadResource(hs, media_repo))
|
self.putChild(b"upload", UploadResource(hs, media_repo))
|
||||||
self.putChild("thumbnail", ThumbnailResource(
|
self.putChild(b"download", DownloadResource(hs, media_repo))
|
||||||
|
self.putChild(b"thumbnail", ThumbnailResource(
|
||||||
hs, media_repo, media_repo.media_storage,
|
hs, media_repo, media_repo.media_storage,
|
||||||
))
|
))
|
||||||
self.putChild("identicon", IdenticonResource())
|
self.putChild(b"identicon", IdenticonResource())
|
||||||
if hs.config.url_preview_enabled:
|
if hs.config.url_preview_enabled:
|
||||||
self.putChild("preview_url", PreviewUrlResource(
|
self.putChild(b"preview_url", PreviewUrlResource(
|
||||||
hs, media_repo, media_repo.media_storage,
|
hs, media_repo, media_repo.media_storage,
|
||||||
))
|
))
|
||||||
self.putChild("config", MediaConfigResource(hs))
|
self.putChild(b"config", MediaConfigResource(hs))
|
||||||
|
|
|
@ -261,7 +261,7 @@ class PreviewUrlResource(Resource):
|
||||||
|
|
||||||
logger.debug("Calculated OG for %s as %s" % (url, og))
|
logger.debug("Calculated OG for %s as %s" % (url, og))
|
||||||
|
|
||||||
jsonog = json.dumps(og)
|
jsonog = json.dumps(og).encode('utf8')
|
||||||
|
|
||||||
# store OG in history-aware DB cache
|
# store OG in history-aware DB cache
|
||||||
yield self.store.store_url_cache(
|
yield self.store.store_url_cache(
|
||||||
|
@ -301,20 +301,20 @@ class PreviewUrlResource(Resource):
|
||||||
logger.warn("Error downloading %s: %r", url, e)
|
logger.warn("Error downloading %s: %r", url, e)
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
500, "Failed to download content: %s" % (
|
500, "Failed to download content: %s" % (
|
||||||
traceback.format_exception_only(sys.exc_type, e),
|
traceback.format_exception_only(sys.exc_info()[0], e),
|
||||||
),
|
),
|
||||||
Codes.UNKNOWN,
|
Codes.UNKNOWN,
|
||||||
)
|
)
|
||||||
yield finish()
|
yield finish()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "Content-Type" in headers:
|
if b"Content-Type" in headers:
|
||||||
media_type = headers["Content-Type"][0]
|
media_type = headers[b"Content-Type"][0].decode('ascii')
|
||||||
else:
|
else:
|
||||||
media_type = "application/octet-stream"
|
media_type = "application/octet-stream"
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
content_disposition = headers.get("Content-Disposition", None)
|
content_disposition = headers.get(b"Content-Disposition", None)
|
||||||
if content_disposition:
|
if content_disposition:
|
||||||
_, params = cgi.parse_header(content_disposition[0],)
|
_, params = cgi.parse_header(content_disposition[0],)
|
||||||
download_name = None
|
download_name = None
|
||||||
|
|
Loading…
Reference in a new issue