Port over enough to get some sytests running on Python 3 (#3668)

This commit is contained in:
Amber Brown 2018-08-20 23:54:49 +10:00 committed by GitHub
parent cf6f9a8b53
commit 324525f40c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 91 additions and 40 deletions

1
changelog.d/3668.misc Normal file
View file

@ -0,0 +1 @@
Port over enough to Python 3 to allow the sytests to start.

View file

@ -211,7 +211,7 @@ class Auth(object):
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent", b"User-Agent",
default=[b""] default=[b""]
)[0] )[0].decode('ascii', 'surrogateescape')
if user and access_token and ip_addr: if user and access_token and ip_addr:
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id=user.to_string(), user_id=user.to_string(),
@ -682,7 +682,7 @@ class Auth(object):
Returns: Returns:
bool: False if no access_token was given, True otherwise. bool: False if no access_token was given, True otherwise.
""" """
query_params = request.args.get("access_token") query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@ -698,7 +698,7 @@ class Auth(object):
401 since some of the old clients depended on auth errors returning 401 since some of the old clients depended on auth errors returning
403. 403.
Returns: Returns:
str: The access_token unicode: The access_token
Raises: Raises:
AuthError: If there isn't an access_token in the request. AuthError: If there isn't an access_token in the request.
""" """
@ -720,9 +720,9 @@ class Auth(object):
"Too many Authorization headers.", "Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN, errcode=Codes.MISSING_TOKEN,
) )
parts = auth_headers[0].split(" ") parts = auth_headers[0].split(b" ")
if parts[0] == "Bearer" and len(parts) == 2: if parts[0] == b"Bearer" and len(parts) == 2:
return parts[1] return parts[1].decode('ascii')
else: else:
raise AuthError( raise AuthError(
token_not_found_http_status, token_not_found_http_status,
@ -738,7 +738,7 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN errcode=Codes.MISSING_TOKEN
) )
return query_params[0] return query_params[0].decode('ascii')
@defer.inlineCallbacks @defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id): def check_in_room_or_world_readable(self, room_id, user_id):

View file

@ -72,7 +72,7 @@ class Ratelimiter(object):
return allowed, time_allowed return allowed, time_allowed
def prune_message_counts(self, time_now_s): def prune_message_counts(self, time_now_s):
for user_id in self.message_counts.keys(): for user_id in list(self.message_counts.keys()):
message_count, time_start, msg_rate_hz = ( message_count, time_start, msg_rate_hz = (
self.message_counts[user_id] self.message_counts[user_id]
) )

View file

@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
if log_file: if log_file:
# TODO: Customisable file size / backup count # TODO: Customisable file size / backup count
handler = logging.handlers.RotatingFileHandler( handler = logging.handlers.RotatingFileHandler(
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
encoding='utf8'
) )
def sighup(signum, stack): def sighup(signum, stack):

View file

@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
name (str): the name of the query parameter. name (bytes/unicode): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults default (int|None): value to use if the parameter is absent, defaults
to None. to None.
required (bool): whether to raise a 400 SynapseError if the required (bool): whether to raise a 400 SynapseError if the
@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False): def parse_integer_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args: if name in args:
try: try:
return int(args[name][0]) return int(args[name][0])
@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
name (str): the name of the query parameter. name (bytes/unicode): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults default (bool|None): value to use if the parameter is absent, defaults
to None. to None.
required (bool): whether to raise a 400 SynapseError if the required (bool): whether to raise a 400 SynapseError if the
@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False): def parse_boolean_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args: if name in args:
try: try:
return { return {
"true": True, b"true": True,
"false": False, b"false": False,
}[args[name][0]] }[args[name][0]]
except Exception: except Exception:
message = ( message = (
@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string(request, name, default=None, required=False, def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string", encoding='ascii'):
"""Parse a string parameter from the request query string. """
Parse a string parameter from the request query string.
If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be returned as bytes.
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
name (str): the name of the query parameter. name (bytes/unicode): the name of the query parameter.
default (str|None): value to use if the parameter is absent, defaults default (bytes/unicode|None): value to use if the parameter is absent,
to None. defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False. parameter is absent, defaults to False.
allowed_values (list[str]): List of allowed values for the string, allowed_values (list[bytes/unicode]): List of allowed values for the
or None if any value is allowed, defaults to None string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the string content with, or None to
leave the content as bytes.
Returns: Returns:
str|None: A string value or the default. bytes/unicode|None: A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises: Raises:
SynapseError if the parameter is absent and required, or if the SynapseError if the parameter is absent and required, or if the
@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
is not one of those allowed values. is not one of those allowed values.
""" """
return parse_string_from_args( return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type, request.args, name, default, required, allowed_values, param_type, encoding
) )
def parse_string_from_args(args, name, default=None, required=False, def parse_string_from_args(args, name, default=None, required=False,
allowed_values=None, param_type="string"): allowed_values=None, param_type="string", encoding='ascii'):
if not isinstance(name, bytes):
name = name.encode('ascii')
if name in args: if name in args:
value = args[name][0] value = args[name][0]
if encoding:
value = value.decode(encoding)
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)
@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
message = "Missing %s query parameter %r" % (param_type, name) message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
if encoding and isinstance(default, bytes):
return default.decode(encoding)
return default return default

View file

@ -235,7 +235,7 @@ class SynapseRequest(Request):
# need to decode as it could be raw utf-8 bytes # need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header # from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity authenticated_entity = self.authenticated_entity
if authenticated_entity is not None: if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace") authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header. # ...or could be raw utf-8 bytes in the User-Agent header.
@ -328,7 +328,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False) proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied) self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string self.server_version_string = server_version_string.encode('ascii')
def log(self, request): def log(self, request):
pass pass

View file

@ -53,7 +53,7 @@ class HttpTransactionCache(object):
str: A transaction key str: A transaction key
""" """
token = self.auth.get_access_token_from_request(request) token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token return request.path.decode('utf8') + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs): def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts """A helper function for fetch_or_execute which extracts

View file

@ -55,7 +55,7 @@ class UploadResource(Resource):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader(b"Content-Length").decode('ascii')
if content_length is None: if content_length is None:
raise SynapseError( raise SynapseError(
msg="Request must specify a Content-Length", code=400 msg="Request must specify a Content-Length", code=400
@ -66,10 +66,10 @@ class UploadResource(Resource):
code=413, code=413,
) )
upload_name = parse_string(request, "filename") upload_name = parse_string(request, b"filename", encoding=None)
if upload_name: if upload_name:
try: try:
upload_name = upload_name.decode('UTF-8') upload_name = upload_name.decode('utf8')
except UnicodeDecodeError: except UnicodeDecodeError:
raise SynapseError( raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
@ -78,8 +78,8 @@ class UploadResource(Resource):
headers = request.requestHeaders headers = request.requestHeaders
if headers.hasHeader("Content-Type"): if headers.hasHeader(b"Content-Type"):
media_type = headers.getRawHeaders(b"Content-Type")[0] media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
else: else:
raise SynapseError( raise SynapseError(
msg="Upload request missing 'Content-Type'", msg="Upload request missing 'Content-Type'",

View file

@ -38,4 +38,4 @@ else:
return os.urandom(nbytes) return os.urandom(nbytes)
def token_hex(self, nbytes=32): def token_hex(self, nbytes=32):
return binascii.hexlify(self.token_bytes(nbytes)) return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')

View file

@ -20,6 +20,8 @@ import time
from functools import wraps from functools import wraps
from inspect import getcallargs from inspect import getcallargs
from six import PY3
_TIME_FUNC_ID = 0 _TIME_FUNC_ID = 0
@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name) logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
lineno = f.func_code.co_firstlineno if PY3:
pathname = f.func_code.co_filename lineno = f.__code__.co_firstlineno
pathname = f.__code__.co_filename
else:
lineno = f.func_code.co_firstlineno
pathname = f.func_code.co_filename
record = logging.LogRecord( record = logging.LogRecord(
name=name, name=name,

View file

@ -16,6 +16,7 @@
import random import random
import string import string
from six import PY3
from six.moves import range from six.moves import range
_string_with_symbols = ( _string_with_symbols = (
@ -34,6 +35,17 @@ def random_string_with_symbols(length):
def is_ascii(s): def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
s.decode('ascii').encode('ascii')
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
return False
return True
try: try:
s.encode("ascii") s.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
@ -49,6 +61,9 @@ def to_ascii(s):
If given None then will return None. If given None then will return None.
""" """
if PY3:
return s
if s is None: if s is None:
return None return None

View file

@ -30,7 +30,7 @@ def get_version_string(module):
['git', 'rev-parse', '--abbrev-ref', 'HEAD'], ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip().decode('ascii')
git_branch = "b=" + git_branch git_branch = "b=" + git_branch
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_branch = "" git_branch = ""
@ -40,7 +40,7 @@ def get_version_string(module):
['git', 'describe', '--exact-match'], ['git', 'describe', '--exact-match'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip().decode('ascii')
git_tag = "t=" + git_tag git_tag = "t=" + git_tag
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_tag = "" git_tag = ""
@ -50,7 +50,7 @@ def get_version_string(module):
['git', 'rev-parse', '--short', 'HEAD'], ['git', 'rev-parse', '--short', 'HEAD'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip().decode('ascii')
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_commit = "" git_commit = ""
@ -60,7 +60,7 @@ def get_version_string(module):
['git', 'describe', '--dirty=' + dirty_string], ['git', 'describe', '--dirty=' + dirty_string],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip().endswith(dirty_string) ).strip().decode('ascii').endswith(dirty_string)
git_dirty = "dirty" if is_dirty else "" git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
@ -77,8 +77,8 @@ def get_version_string(module):
"%s (%s)" % ( "%s (%s)" % (
module.__version__, git_version, module.__version__, git_version,
) )
).encode("ascii") )
except Exception as e: except Exception as e:
logger.info("Failed to check for git repository: %s", e) logger.info("Failed to check for git repository: %s", e)
return module.__version__.encode("ascii") return module.__version__