forked from MirrorHub/synapse
Port over enough to get some sytests running on Python 3 (#3668)
This commit is contained in:
parent
cf6f9a8b53
commit
324525f40c
12 changed files with 91 additions and 40 deletions
1
changelog.d/3668.misc
Normal file
1
changelog.d/3668.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Port over enough to Python 3 to allow the sytests to start.
|
|
@ -211,7 +211,7 @@ class Auth(object):
|
|||
user_agent = request.requestHeaders.getRawHeaders(
|
||||
b"User-Agent",
|
||||
default=[b""]
|
||||
)[0]
|
||||
)[0].decode('ascii', 'surrogateescape')
|
||||
if user and access_token and ip_addr:
|
||||
yield self.store.insert_client_ip(
|
||||
user_id=user.to_string(),
|
||||
|
@ -682,7 +682,7 @@ class Auth(object):
|
|||
Returns:
|
||||
bool: False if no access_token was given, True otherwise.
|
||||
"""
|
||||
query_params = request.args.get("access_token")
|
||||
query_params = request.args.get(b"access_token")
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
return bool(query_params) or bool(auth_headers)
|
||||
|
||||
|
@ -698,7 +698,7 @@ class Auth(object):
|
|||
401 since some of the old clients depended on auth errors returning
|
||||
403.
|
||||
Returns:
|
||||
str: The access_token
|
||||
unicode: The access_token
|
||||
Raises:
|
||||
AuthError: If there isn't an access_token in the request.
|
||||
"""
|
||||
|
@ -720,9 +720,9 @@ class Auth(object):
|
|||
"Too many Authorization headers.",
|
||||
errcode=Codes.MISSING_TOKEN,
|
||||
)
|
||||
parts = auth_headers[0].split(" ")
|
||||
if parts[0] == "Bearer" and len(parts) == 2:
|
||||
return parts[1]
|
||||
parts = auth_headers[0].split(b" ")
|
||||
if parts[0] == b"Bearer" and len(parts) == 2:
|
||||
return parts[1].decode('ascii')
|
||||
else:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
|
@ -738,7 +738,7 @@ class Auth(object):
|
|||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
|
||||
return query_params[0]
|
||||
return query_params[0].decode('ascii')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_in_room_or_world_readable(self, room_id, user_id):
|
||||
|
|
|
@ -72,7 +72,7 @@ class Ratelimiter(object):
|
|||
return allowed, time_allowed
|
||||
|
||||
def prune_message_counts(self, time_now_s):
|
||||
for user_id in self.message_counts.keys():
|
||||
for user_id in list(self.message_counts.keys()):
|
||||
message_count, time_start, msg_rate_hz = (
|
||||
self.message_counts[user_id]
|
||||
)
|
||||
|
|
|
@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
|
|||
if log_file:
|
||||
# TODO: Customisable file size / backup count
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
||||
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
|
||||
encoding='utf8'
|
||||
)
|
||||
|
||||
def sighup(signum, stack):
|
||||
|
|
|
@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
|
|||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
name (str): the name of the query parameter.
|
||||
name (bytes/unicode): the name of the query parameter.
|
||||
default (int|None): value to use if the parameter is absent, defaults
|
||||
to None.
|
||||
required (bool): whether to raise a 400 SynapseError if the
|
||||
|
@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
|
|||
|
||||
|
||||
def parse_integer_from_args(args, name, default=None, required=False):
|
||||
|
||||
if not isinstance(name, bytes):
|
||||
name = name.encode('ascii')
|
||||
|
||||
if name in args:
|
||||
try:
|
||||
return int(args[name][0])
|
||||
|
@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
|
|||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
name (str): the name of the query parameter.
|
||||
name (bytes/unicode): the name of the query parameter.
|
||||
default (bool|None): value to use if the parameter is absent, defaults
|
||||
to None.
|
||||
required (bool): whether to raise a 400 SynapseError if the
|
||||
|
@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
|
|||
|
||||
|
||||
def parse_boolean_from_args(args, name, default=None, required=False):
|
||||
|
||||
if not isinstance(name, bytes):
|
||||
name = name.encode('ascii')
|
||||
|
||||
if name in args:
|
||||
try:
|
||||
return {
|
||||
"true": True,
|
||||
"false": False,
|
||||
b"true": True,
|
||||
b"false": False,
|
||||
}[args[name][0]]
|
||||
except Exception:
|
||||
message = (
|
||||
|
@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
|
|||
|
||||
|
||||
def parse_string(request, name, default=None, required=False,
|
||||
allowed_values=None, param_type="string"):
|
||||
"""Parse a string parameter from the request query string.
|
||||
allowed_values=None, param_type="string", encoding='ascii'):
|
||||
"""
|
||||
Parse a string parameter from the request query string.
|
||||
|
||||
If encoding is not None, the content of the query param will be
|
||||
decoded to Unicode using the encoding, otherwise it will be encoded
|
||||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
name (str): the name of the query parameter.
|
||||
default (str|None): value to use if the parameter is absent, defaults
|
||||
to None.
|
||||
name (bytes/unicode): the name of the query parameter.
|
||||
default (bytes/unicode|None): value to use if the parameter is absent,
|
||||
defaults to None. Must be bytes if encoding is None.
|
||||
required (bool): whether to raise a 400 SynapseError if the
|
||||
parameter is absent, defaults to False.
|
||||
allowed_values (list[str]): List of allowed values for the string,
|
||||
or None if any value is allowed, defaults to None
|
||||
allowed_values (list[bytes/unicode]): List of allowed values for the
|
||||
string, or None if any value is allowed, defaults to None. Must be
|
||||
the same type as name, if given.
|
||||
encoding: The encoding to decode the name to, and decode the string
|
||||
content with.
|
||||
|
||||
Returns:
|
||||
str|None: A string value or the default.
|
||||
bytes/unicode|None: A string value or the default. Unicode if encoding
|
||||
was given, bytes otherwise.
|
||||
|
||||
Raises:
|
||||
SynapseError if the parameter is absent and required, or if the
|
||||
|
@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
|
|||
is not one of those allowed values.
|
||||
"""
|
||||
return parse_string_from_args(
|
||||
request.args, name, default, required, allowed_values, param_type,
|
||||
request.args, name, default, required, allowed_values, param_type, encoding
|
||||
)
|
||||
|
||||
|
||||
def parse_string_from_args(args, name, default=None, required=False,
|
||||
allowed_values=None, param_type="string"):
|
||||
allowed_values=None, param_type="string", encoding='ascii'):
|
||||
|
||||
if not isinstance(name, bytes):
|
||||
name = name.encode('ascii')
|
||||
|
||||
if name in args:
|
||||
value = args[name][0]
|
||||
|
||||
if encoding:
|
||||
value = value.decode(encoding)
|
||||
|
||||
if allowed_values is not None and value not in allowed_values:
|
||||
message = "Query parameter %r must be one of [%s]" % (
|
||||
name, ", ".join(repr(v) for v in allowed_values)
|
||||
|
@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
|
|||
message = "Missing %s query parameter %r" % (param_type, name)
|
||||
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||
else:
|
||||
|
||||
if encoding and isinstance(default, bytes):
|
||||
return default.decode(encoding)
|
||||
|
||||
return default
|
||||
|
||||
|
||||
|
|
|
@ -235,7 +235,7 @@ class SynapseRequest(Request):
|
|||
# need to decode as it could be raw utf-8 bytes
|
||||
# from a IDN servname in an auth header
|
||||
authenticated_entity = self.authenticated_entity
|
||||
if authenticated_entity is not None:
|
||||
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
|
||||
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
|
||||
|
||||
# ...or could be raw utf-8 bytes in the User-Agent header.
|
||||
|
@ -328,7 +328,7 @@ class SynapseSite(Site):
|
|||
proxied = config.get("x_forwarded", False)
|
||||
self.requestFactory = SynapseRequestFactory(self, proxied)
|
||||
self.access_logger = logging.getLogger(logger_name)
|
||||
self.server_version_string = server_version_string
|
||||
self.server_version_string = server_version_string.encode('ascii')
|
||||
|
||||
def log(self, request):
|
||||
pass
|
||||
|
|
|
@ -53,7 +53,7 @@ class HttpTransactionCache(object):
|
|||
str: A transaction key
|
||||
"""
|
||||
token = self.auth.get_access_token_from_request(request)
|
||||
return request.path + "/" + token
|
||||
return request.path.decode('utf8') + "/" + token
|
||||
|
||||
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
||||
"""A helper function for fetch_or_execute which extracts
|
||||
|
|
|
@ -55,7 +55,7 @@ class UploadResource(Resource):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
content_length = request.getHeader(b"Content-Length").decode('ascii')
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
|
@ -66,10 +66,10 @@ class UploadResource(Resource):
|
|||
code=413,
|
||||
)
|
||||
|
||||
upload_name = parse_string(request, "filename")
|
||||
upload_name = parse_string(request, b"filename", encoding=None)
|
||||
if upload_name:
|
||||
try:
|
||||
upload_name = upload_name.decode('UTF-8')
|
||||
upload_name = upload_name.decode('utf8')
|
||||
except UnicodeDecodeError:
|
||||
raise SynapseError(
|
||||
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
|
||||
|
@ -78,8 +78,8 @@ class UploadResource(Resource):
|
|||
|
||||
headers = request.requestHeaders
|
||||
|
||||
if headers.hasHeader("Content-Type"):
|
||||
media_type = headers.getRawHeaders(b"Content-Type")[0]
|
||||
if headers.hasHeader(b"Content-Type"):
|
||||
media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
|
||||
else:
|
||||
raise SynapseError(
|
||||
msg="Upload request missing 'Content-Type'",
|
||||
|
|
|
@ -38,4 +38,4 @@ else:
|
|||
return os.urandom(nbytes)
|
||||
|
||||
def token_hex(self, nbytes=32):
|
||||
return binascii.hexlify(self.token_bytes(nbytes))
|
||||
return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
|
||||
|
|
|
@ -20,6 +20,8 @@ import time
|
|||
from functools import wraps
|
||||
from inspect import getcallargs
|
||||
|
||||
from six import PY3
|
||||
|
||||
_TIME_FUNC_ID = 0
|
||||
|
||||
|
||||
|
@ -28,6 +30,10 @@ def _log_debug_as_f(f, msg, msg_args):
|
|||
logger = logging.getLogger(name)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if PY3:
|
||||
lineno = f.__code__.co_firstlineno
|
||||
pathname = f.__code__.co_filename
|
||||
else:
|
||||
lineno = f.func_code.co_firstlineno
|
||||
pathname = f.func_code.co_filename
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import random
|
||||
import string
|
||||
|
||||
from six import PY3
|
||||
from six.moves import range
|
||||
|
||||
_string_with_symbols = (
|
||||
|
@ -34,6 +35,17 @@ def random_string_with_symbols(length):
|
|||
|
||||
|
||||
def is_ascii(s):
|
||||
|
||||
if PY3:
|
||||
if isinstance(s, bytes):
|
||||
try:
|
||||
s.decode('ascii').encode('ascii')
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
except UnicodeEncodeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
try:
|
||||
s.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
|
@ -49,6 +61,9 @@ def to_ascii(s):
|
|||
|
||||
If given None then will return None.
|
||||
"""
|
||||
if PY3:
|
||||
return s
|
||||
|
||||
if s is None:
|
||||
return None
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ def get_version_string(module):
|
|||
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
|
||||
stderr=null,
|
||||
cwd=cwd,
|
||||
).strip()
|
||||
).strip().decode('ascii')
|
||||
git_branch = "b=" + git_branch
|
||||
except subprocess.CalledProcessError:
|
||||
git_branch = ""
|
||||
|
@ -40,7 +40,7 @@ def get_version_string(module):
|
|||
['git', 'describe', '--exact-match'],
|
||||
stderr=null,
|
||||
cwd=cwd,
|
||||
).strip()
|
||||
).strip().decode('ascii')
|
||||
git_tag = "t=" + git_tag
|
||||
except subprocess.CalledProcessError:
|
||||
git_tag = ""
|
||||
|
@ -50,7 +50,7 @@ def get_version_string(module):
|
|||
['git', 'rev-parse', '--short', 'HEAD'],
|
||||
stderr=null,
|
||||
cwd=cwd,
|
||||
).strip()
|
||||
).strip().decode('ascii')
|
||||
except subprocess.CalledProcessError:
|
||||
git_commit = ""
|
||||
|
||||
|
@ -60,7 +60,7 @@ def get_version_string(module):
|
|||
['git', 'describe', '--dirty=' + dirty_string],
|
||||
stderr=null,
|
||||
cwd=cwd,
|
||||
).strip().endswith(dirty_string)
|
||||
).strip().decode('ascii').endswith(dirty_string)
|
||||
|
||||
git_dirty = "dirty" if is_dirty else ""
|
||||
except subprocess.CalledProcessError:
|
||||
|
@ -77,8 +77,8 @@ def get_version_string(module):
|
|||
"%s (%s)" % (
|
||||
module.__version__, git_version,
|
||||
)
|
||||
).encode("ascii")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("Failed to check for git repository: %s", e)
|
||||
|
||||
return module.__version__.encode("ascii")
|
||||
return module.__version__
|
||||
|
|
Loading…
Reference in a new issue