forked from MirrorHub/synapse
Port over enough to get some sytests running on Python 3 (#3668)
This commit is contained in:
parent
cf6f9a8b53
commit
324525f40c
12 changed files with 91 additions and 40 deletions
1
changelog.d/3668.misc
Normal file
1
changelog.d/3668.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Port over enough to Python 3 to allow the sytests to start.
|
|
@ -211,7 +211,7 @@ class Auth(object):
|
||||||
user_agent = request.requestHeaders.getRawHeaders(
|
user_agent = request.requestHeaders.getRawHeaders(
|
||||||
b"User-Agent",
|
b"User-Agent",
|
||||||
default=[b""]
|
default=[b""]
|
||||||
)[0]
|
)[0].decode('ascii', 'surrogateescape')
|
||||||
if user and access_token and ip_addr:
|
if user and access_token and ip_addr:
|
||||||
yield self.store.insert_client_ip(
|
yield self.store.insert_client_ip(
|
||||||
user_id=user.to_string(),
|
user_id=user.to_string(),
|
||||||
|
@ -682,7 +682,7 @@ class Auth(object):
|
||||||
Returns:
|
Returns:
|
||||||
bool: False if no access_token was given, True otherwise.
|
bool: False if no access_token was given, True otherwise.
|
||||||
"""
|
"""
|
||||||
query_params = request.args.get("access_token")
|
query_params = request.args.get(b"access_token")
|
||||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||||
return bool(query_params) or bool(auth_headers)
|
return bool(query_params) or bool(auth_headers)
|
||||||
|
|
||||||
|
@ -698,7 +698,7 @@ class Auth(object):
|
||||||
401 since some of the old clients depended on auth errors returning
|
401 since some of the old clients depended on auth errors returning
|
||||||
403.
|
403.
|
||||||
Returns:
|
Returns:
|
||||||
str: The access_token
|
unicode: The access_token
|
||||||
Raises:
|
Raises:
|
||||||
AuthError: If there isn't an access_token in the request.
|
AuthError: If there isn't an access_token in the request.
|
||||||
"""
|
"""
|
||||||
|
@ -720,9 +720,9 @@ class Auth(object):
|
||||||
"Too many Authorization headers.",
|
"Too many Authorization headers.",
|
||||||
errcode=Codes.MISSING_TOKEN,
|
errcode=Codes.MISSING_TOKEN,
|
||||||
)
|
)
|
||||||
parts = auth_headers[0].split(" ")
|
parts = auth_headers[0].split(b" ")
|
||||||
if parts[0] == "Bearer" and len(parts) == 2:
|
if parts[0] == b"Bearer" and len(parts) == 2:
|
||||||
return parts[1]
|
return parts[1].decode('ascii')
|
||||||
else:
|
else:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
token_not_found_http_status,
|
token_not_found_http_status,
|
||||||
|
@ -738,7 +738,7 @@ class Auth(object):
|
||||||
errcode=Codes.MISSING_TOKEN
|
errcode=Codes.MISSING_TOKEN
|
||||||
)
|
)
|
||||||
|
|
||||||
return query_params[0]
|
return query_params[0].decode('ascii')
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_in_room_or_world_readable(self, room_id, user_id):
|
def check_in_room_or_world_readable(self, room_id, user_id):
|
||||||
|
|
|
@ -72,7 +72,7 @@ class Ratelimiter(object):
|
||||||
return allowed, time_allowed
|
return allowed, time_allowed
|
||||||
|
|
||||||
def prune_message_counts(self, time_now_s):
|
def prune_message_counts(self, time_now_s):
|
||||||
for user_id in self.message_counts.keys():
|
for user_id in list(self.message_counts.keys()):
|
||||||
message_count, time_start, msg_rate_hz = (
|
message_count, time_start, msg_rate_hz = (
|
||||||
self.message_counts[user_id]
|
self.message_counts[user_id]
|
||||||
)
|
)
|
||||||
|
|
|
@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
|
||||||
if log_file:
|
if log_file:
|
||||||
# TODO: Customisable file size / backup count
|
# TODO: Customisable file size / backup count
|
||||||
handler = logging.handlers.RotatingFileHandler(
|
handler = logging.handlers.RotatingFileHandler(
|
||||||
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
|
||||||
|
encoding='utf8'
|
||||||
)
|
)
|
||||||
|
|
||||||
def sighup(signum, stack):
|
def sighup(signum, stack):
|
||||||
|
|
|
@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: the twisted HTTP request.
|
request: the twisted HTTP request.
|
||||||
name (str): the name of the query parameter.
|
name (bytes/unicode): the name of the query parameter.
|
||||||
default (int|None): value to use if the parameter is absent, defaults
|
default (int|None): value to use if the parameter is absent, defaults
|
||||||
to None.
|
to None.
|
||||||
required (bool): whether to raise a 400 SynapseError if the
|
required (bool): whether to raise a 400 SynapseError if the
|
||||||
|
@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
|
||||||
|
|
||||||
|
|
||||||
def parse_integer_from_args(args, name, default=None, required=False):
|
def parse_integer_from_args(args, name, default=None, required=False):
|
||||||
|
|
||||||
|
if not isinstance(name, bytes):
|
||||||
|
name = name.encode('ascii')
|
||||||
|
|
||||||
if name in args:
|
if name in args:
|
||||||
try:
|
try:
|
||||||
return int(args[name][0])
|
return int(args[name][0])
|
||||||
|
@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: the twisted HTTP request.
|
request: the twisted HTTP request.
|
||||||
name (str): the name of the query parameter.
|
name (bytes/unicode): the name of the query parameter.
|
||||||
default (bool|None): value to use if the parameter is absent, defaults
|
default (bool|None): value to use if the parameter is absent, defaults
|
||||||
to None.
|
to None.
|
||||||
required (bool): whether to raise a 400 SynapseError if the
|
required (bool): whether to raise a 400 SynapseError if the
|
||||||
|
@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
|
||||||
|
|
||||||
|
|
||||||
def parse_boolean_from_args(args, name, default=None, required=False):
|
def parse_boolean_from_args(args, name, default=None, required=False):
|
||||||
|
|
||||||
|
if not isinstance(name, bytes):
|
||||||
|
name = name.encode('ascii')
|
||||||
|
|
||||||
if name in args:
|
if name in args:
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
"true": True,
|
b"true": True,
|
||||||
"false": False,
|
b"false": False,
|
||||||
}[args[name][0]]
|
}[args[name][0]]
|
||||||
except Exception:
|
except Exception:
|
||||||
message = (
|
message = (
|
||||||
|
@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
|
||||||
|
|
||||||
|
|
||||||
def parse_string(request, name, default=None, required=False,
|
def parse_string(request, name, default=None, required=False,
|
||||||
allowed_values=None, param_type="string"):
|
allowed_values=None, param_type="string", encoding='ascii'):
|
||||||
"""Parse a string parameter from the request query string.
|
"""
|
||||||
|
Parse a string parameter from the request query string.
|
||||||
|
|
||||||
|
If encoding is not None, the content of the query param will be
|
||||||
|
decoded to Unicode using the encoding, otherwise it will be encoded
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: the twisted HTTP request.
|
request: the twisted HTTP request.
|
||||||
name (str): the name of the query parameter.
|
name (bytes/unicode): the name of the query parameter.
|
||||||
default (str|None): value to use if the parameter is absent, defaults
|
default (bytes/unicode|None): value to use if the parameter is absent,
|
||||||
to None.
|
defaults to None. Must be bytes if encoding is None.
|
||||||
required (bool): whether to raise a 400 SynapseError if the
|
required (bool): whether to raise a 400 SynapseError if the
|
||||||
parameter is absent, defaults to False.
|
parameter is absent, defaults to False.
|
||||||
allowed_values (list[str]): List of allowed values for the string,
|
allowed_values (list[bytes/unicode]): List of allowed values for the
|
||||||
or None if any value is allowed, defaults to None
|
string, or None if any value is allowed, defaults to None. Must be
|
||||||
|
the same type as name, if given.
|
||||||
|
encoding: The encoding to decode the name to, and decode the string
|
||||||
|
content with.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str|None: A string value or the default.
|
bytes/unicode|None: A string value or the default. Unicode if encoding
|
||||||
|
was given, bytes otherwise.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if the parameter is absent and required, or if the
|
SynapseError if the parameter is absent and required, or if the
|
||||||
|
@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
|
||||||
is not one of those allowed values.
|
is not one of those allowed values.
|
||||||
"""
|
"""
|
||||||
return parse_string_from_args(
|
return parse_string_from_args(
|
||||||
request.args, name, default, required, allowed_values, param_type,
|
request.args, name, default, required, allowed_values, param_type, encoding
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_string_from_args(args, name, default=None, required=False,
|
def parse_string_from_args(args, name, default=None, required=False,
|
||||||
allowed_values=None, param_type="string"):
|
allowed_values=None, param_type="string", encoding='ascii'):
|
||||||
|
|
||||||
|
if not isinstance(name, bytes):
|
||||||
|
name = name.encode('ascii')
|
||||||
|
|
||||||
if name in args:
|
if name in args:
|
||||||
value = args[name][0]
|
value = args[name][0]
|
||||||
|
|
||||||
|
if encoding:
|
||||||
|
value = value.decode(encoding)
|
||||||
|
|
||||||
if allowed_values is not None and value not in allowed_values:
|
if allowed_values is not None and value not in allowed_values:
|
||||||
message = "Query parameter %r must be one of [%s]" % (
|
message = "Query parameter %r must be one of [%s]" % (
|
||||||
name, ", ".join(repr(v) for v in allowed_values)
|
name, ", ".join(repr(v) for v in allowed_values)
|
||||||
|
@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
|
||||||
message = "Missing %s query parameter %r" % (param_type, name)
|
message = "Missing %s query parameter %r" % (param_type, name)
|
||||||
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
if encoding and isinstance(default, bytes):
|
||||||
|
return default.decode(encoding)
|
||||||
|
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -235,7 +235,7 @@ class SynapseRequest(Request):
|
||||||
# need to decode as it could be raw utf-8 bytes
|
# need to decode as it could be raw utf-8 bytes
|
||||||
# from a IDN servname in an auth header
|
# from a IDN servname in an auth header
|
||||||
authenticated_entity = self.authenticated_entity
|
authenticated_entity = self.authenticated_entity
|
||||||
if authenticated_entity is not None:
|
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
|
||||||
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
|
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
|
||||||
|
|
||||||
# ...or could be raw utf-8 bytes in the User-Agent header.
|
# ...or could be raw utf-8 bytes in the User-Agent header.
|
||||||
|
@ -328,7 +328,7 @@ class SynapseSite(Site):
|
||||||
proxied = config.get("x_forwarded", False)
|
proxied = config.get("x_forwarded", False)
|
||||||
self.requestFactory = SynapseRequestFactory(self, proxied)
|
self.requestFactory = SynapseRequestFactory(self, proxied)
|
||||||
self.access_logger = logging.getLogger(logger_name)
|
self.access_logger = logging.getLogger(logger_name)
|
||||||
self.server_version_string = server_version_string
|
self.server_version_string = server_version_string.encode('ascii')
|
||||||
|
|
||||||
def log(self, request):
|
def log(self, request):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -53,7 +53,7 @@ class HttpTransactionCache(object):
|
||||||
str: A transaction key
|
str: A transaction key
|
||||||
"""
|
"""
|
||||||
token = self.auth.get_access_token_from_request(request)
|
token = self.auth.get_access_token_from_request(request)
|
||||||
return request.path + "/" + token
|
return request.path.decode('utf8') + "/" + token
|
||||||
|
|
||||||
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
||||||
"""A helper function for fetch_or_execute which extracts
|
"""A helper function for fetch_or_execute which extracts
|
||||||
|
|
|
@ -55,7 +55,7 @@ class UploadResource(Resource):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
# TODO: The checks here are a bit late. The content will have
|
# TODO: The checks here are a bit late. The content will have
|
||||||
# already been uploaded to a tmp file at this point
|
# already been uploaded to a tmp file at this point
|
||||||
content_length = request.getHeader("Content-Length")
|
content_length = request.getHeader(b"Content-Length").decode('ascii')
|
||||||
if content_length is None:
|
if content_length is None:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
msg="Request must specify a Content-Length", code=400
|
msg="Request must specify a Content-Length", code=400
|
||||||
|
@ -66,10 +66,10 @@ class UploadResource(Resource):
|
||||||
code=413,
|
code=413,
|
||||||
)
|
)
|
||||||
|
|
||||||
upload_name = parse_string(request, "filename")
|
upload_name = parse_string(request, b"filename", encoding=None)
|
||||||
if upload_name:
|
if upload_name:
|
||||||
try:
|
try:
|
||||||
upload_name = upload_name.decode('UTF-8')
|
upload_name = upload_name.decode('utf8')
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
|
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
|
||||||
|
@ -78,8 +78,8 @@ class UploadResource(Resource):
|
||||||
|
|
||||||
headers = request.requestHeaders
|
headers = request.requestHeaders
|
||||||
|
|
||||||
if headers.hasHeader("Content-Type"):
|
if headers.hasHeader(b"Content-Type"):
|
||||||
media_type = headers.getRawHeaders(b"Content-Type")[0]
|
media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
|
||||||
else:
|
else:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
msg="Upload request missing 'Content-Type'",
|
msg="Upload request missing 'Content-Type'",
|
||||||
|
|
|
@ -38,4 +38,4 @@ else:
|
||||||
return os.urandom(nbytes)
|
return os.urandom(nbytes)
|
||||||
|
|
||||||
def token_hex(self, nbytes=32):
|
def token_hex(self, nbytes=32):
|
||||||
return binascii.hexlify(self.token_bytes(nbytes))
|
return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
|
||||||
|
|
|
@ -20,6 +20,8 @@ import time
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from inspect import getcallargs
|
from inspect import getcallargs
|
||||||
|
|
||||||
|
from six import PY3
|
||||||
|
|
||||||
_TIME_FUNC_ID = 0
|
_TIME_FUNC_ID = 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
|
|
||||||
if logger.isEnabledFor(logging.DEBUG):
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
lineno = f.func_code.co_firstlineno
|
if PY3:
|
||||||
pathname = f.func_code.co_filename
|
lineno = f.__code__.co_firstlineno
|
||||||
|
pathname = f.__code__.co_filename
|
||||||
|
else:
|
||||||
|
lineno = f.func_code.co_firstlineno
|
||||||
|
pathname = f.func_code.co_filename
|
||||||
|
|
||||||
record = logging.LogRecord(
|
record = logging.LogRecord(
|
||||||
name=name,
|
name=name,
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
|
||||||
|
from six import PY3
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
_string_with_symbols = (
|
_string_with_symbols = (
|
||||||
|
@ -34,6 +35,17 @@ def random_string_with_symbols(length):
|
||||||
|
|
||||||
|
|
||||||
def is_ascii(s):
|
def is_ascii(s):
|
||||||
|
|
||||||
|
if PY3:
|
||||||
|
if isinstance(s, bytes):
|
||||||
|
try:
|
||||||
|
s.decode('ascii').encode('ascii')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return False
|
||||||
|
except UnicodeEncodeError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
s.encode("ascii")
|
s.encode("ascii")
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
|
@ -49,6 +61,9 @@ def to_ascii(s):
|
||||||
|
|
||||||
If given None then will return None.
|
If given None then will return None.
|
||||||
"""
|
"""
|
||||||
|
if PY3:
|
||||||
|
return s
|
||||||
|
|
||||||
if s is None:
|
if s is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ def get_version_string(module):
|
||||||
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
|
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
|
||||||
stderr=null,
|
stderr=null,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
).strip()
|
).strip().decode('ascii')
|
||||||
git_branch = "b=" + git_branch
|
git_branch = "b=" + git_branch
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
git_branch = ""
|
git_branch = ""
|
||||||
|
@ -40,7 +40,7 @@ def get_version_string(module):
|
||||||
['git', 'describe', '--exact-match'],
|
['git', 'describe', '--exact-match'],
|
||||||
stderr=null,
|
stderr=null,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
).strip()
|
).strip().decode('ascii')
|
||||||
git_tag = "t=" + git_tag
|
git_tag = "t=" + git_tag
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
git_tag = ""
|
git_tag = ""
|
||||||
|
@ -50,7 +50,7 @@ def get_version_string(module):
|
||||||
['git', 'rev-parse', '--short', 'HEAD'],
|
['git', 'rev-parse', '--short', 'HEAD'],
|
||||||
stderr=null,
|
stderr=null,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
).strip()
|
).strip().decode('ascii')
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
git_commit = ""
|
git_commit = ""
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def get_version_string(module):
|
||||||
['git', 'describe', '--dirty=' + dirty_string],
|
['git', 'describe', '--dirty=' + dirty_string],
|
||||||
stderr=null,
|
stderr=null,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
).strip().endswith(dirty_string)
|
).strip().decode('ascii').endswith(dirty_string)
|
||||||
|
|
||||||
git_dirty = "dirty" if is_dirty else ""
|
git_dirty = "dirty" if is_dirty else ""
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
|
@ -77,8 +77,8 @@ def get_version_string(module):
|
||||||
"%s (%s)" % (
|
"%s (%s)" % (
|
||||||
module.__version__, git_version,
|
module.__version__, git_version,
|
||||||
)
|
)
|
||||||
).encode("ascii")
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info("Failed to check for git repository: %s", e)
|
logger.info("Failed to check for git repository: %s", e)
|
||||||
|
|
||||||
return module.__version__.encode("ascii")
|
return module.__version__
|
||||||
|
|
Loading…
Reference in a new issue