Python 3: Convert some unicode/bytes uses (#3569)

This commit is contained in:
Amber Brown 2018-08-02 00:54:06 +10:00 committed by GitHub
parent c4842e16cb
commit da7785147d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 122 additions and 67 deletions

1
changelog.d/3569.bugfix Normal file
View file

@ -0,0 +1 @@
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.

View file

@ -252,10 +252,10 @@ class Auth(object):
if ip_address not in app_service.ip_range_whitelist: if ip_address not in app_service.ip_range_whitelist:
defer.returnValue((None, None)) defer.returnValue((None, None))
if "user_id" not in request.args: if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service)) defer.returnValue((app_service.sender, app_service))
user_id = request.args["user_id"][0] user_id = request.args[b"user_id"][0].decode('utf8')
if app_service.sender == user_id: if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service)) defer.returnValue((app_service.sender, app_service))

View file

@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
param_dict = dict(kv.split("=") for kv in params) param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value): def strip_quotes(value):
if value.startswith(b"\""): if value.startswith("\""):
return value[1:-1] return value[1:-1]
else: else:
return value return value

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import unicodedata
import attr import attr
import bcrypt import bcrypt
@ -626,6 +627,7 @@ class AuthHandler(BaseHandler):
# special case to check for "password" for the check_password interface # special case to check for "password" for the check_password interface
# for the auth providers # for the auth providers
password = login_submission.get("password") password = login_submission.get("password")
if login_type == LoginType.PASSWORD: if login_type == LoginType.PASSWORD:
if not self._password_enabled: if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.") raise SynapseError(400, "Password login has been disabled.")
@ -707,9 +709,10 @@ class AuthHandler(BaseHandler):
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (str): complete @user:id user_id (unicode): complete @user:id
password (unicode): the provided password
Returns: Returns:
(str) the canonical_user_id, or None if unknown user / bad password (unicode) the canonical_user_id, or None if unknown user / bad password
""" """
lookupres = yield self._find_user_id_and_pwd_hash(user_id) lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres: if not lookupres:
@ -849,14 +852,19 @@ class AuthHandler(BaseHandler):
"""Computes a secure hash of password. """Computes a secure hash of password.
Args: Args:
password (str): Password to hash. password (unicode): Password to hash.
Returns: Returns:
Deferred(str): Hashed password. Deferred(unicode): Hashed password.
""" """
def _do_hash(): def _do_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, # Normalise the Unicode in the password
bcrypt.gensalt(self.bcrypt_rounds)) pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
).decode('ascii')
return make_deferred_yieldable( return make_deferred_yieldable(
threads.deferToThreadPool( threads.deferToThreadPool(
@ -868,16 +876,19 @@ class AuthHandler(BaseHandler):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
Args: Args:
password (str): Password to hash. password (unicode): Password to hash.
stored_hash (str): Expected hash value. stored_hash (unicode): Expected hash value.
Returns: Returns:
Deferred(bool): Whether self.hash(password) == stored_hash. Deferred(bool): Whether self.hash(password) == stored_hash.
""" """
def _do_validate_hash(): def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw( return bcrypt.checkpw(
password.encode('utf8') + self.hs.config.password_pepper, pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8') stored_hash.encode('utf8')
) )

View file

@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be generated. one will be generated.
password (str) : The password to assign to this user so they can password (unicode) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be generate_token (bool): Whether a new access token should be

View file

@ -13,12 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cgi import cgi
import collections import collections
import logging import logging
import urllib
from six.moves import http_client from six import PY3
from six.moves import http_client, urllib
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -264,6 +265,7 @@ class JsonResource(HttpServer, resource.Resource):
self.hs = hs self.hs = hs
def register_paths(self, method, path_patterns, callback): def register_paths(self, method, path_patterns, callback):
method = method.encode("utf-8") # method is bytes on py3
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern) logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
@ -296,8 +298,19 @@ class JsonResource(HttpServer, resource.Resource):
# here. If it throws an exception, that is handled by the wrapper # here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler. # installed by @request_handler.
def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
kwargs = intern_dict({ kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value name: _unquote(value) if value else value
for name, value in group_dict.items() for name, value in group_dict.items()
}) })
@ -313,9 +326,9 @@ class JsonResource(HttpServer, resource.Resource):
request (twisted.web.http.Request): request (twisted.web.http.Request):
Returns: Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict Tuple[Callable, dict[unicode, unicode]]: callback method, and the
mapping keys to path components as specified in the handler's dict mapping keys to path components as specified in the
path match regexp. handler's path match regexp.
The callback will normally be a method registered via The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either register_paths, so will return (possibly via Deferred) either
@ -327,7 +340,7 @@ class JsonResource(HttpServer, resource.Resource):
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path) m = path_entry.pattern.match(request.path.decode('ascii'))
if m: if m:
# We found a match! # We found a match!
return path_entry.callback, m.groupdict() return path_entry.callback, m.groupdict()
@ -383,7 +396,7 @@ class RootRedirect(resource.Resource):
self.url = path self.url = path
def render_GET(self, request): def render_GET(self, request):
return redirectTo(self.url, request) return redirectTo(self.url.encode('ascii'), request)
def getChild(self, name, request): def getChild(self, name, request):
if len(name) == 0: if len(name) == 0:
@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
return return
if pretty_print: if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n" json_bytes = (encode_pretty_printed_json(json_object) + "\n"
).encode("utf-8")
else: else:
if canonical_json or synapse.events.USE_FROZEN_DICTS: if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
else: else:
json_bytes = json.dumps(json_object) json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes( return respond_with_json_bytes(
request, code, json_bytes, request, code, json_bytes,

View file

@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
if not content_bytes and allow_empty_body: if not content_bytes and allow_empty_body:
return None return None
# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try: try:
content = json.loads(content_bytes) content_unicode = content_bytes.decode('utf8')
except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
try:
content = json.loads(content_unicode)
except Exception as e: except Exception as e:
logger.warn("Unable to parse JSON: %s", e) logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View file

@ -18,6 +18,7 @@ import hashlib
import hmac import hmac
import logging import logging
from six import text_type
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer from twisted.internet import defer
@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "username must be specified", errcode=Codes.BAD_JSON, 400, "username must be specified", errcode=Codes.BAD_JSON,
) )
else: else:
if (not isinstance(body['username'], str) or len(body['username']) > 512): if (
not isinstance(body['username'], text_type)
or len(body['username']) > 512
):
raise SynapseError(400, "Invalid username") raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8") username = body["username"].encode("utf-8")
@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
400, "password must be specified", errcode=Codes.BAD_JSON, 400, "password must be specified", errcode=Codes.BAD_JSON,
) )
else: else:
if (not isinstance(body['password'], str) or len(body['password']) > 512): if (
not isinstance(body['password'], text_type)
or len(body['password']) > 512
):
raise SynapseError(400, "Invalid password") raise SynapseError(400, "Invalid password")
password = body["password"].encode("utf-8") password = body["password"].encode("utf-8")
@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(b"admin" if admin else b"notadmin") want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest() want_mac = want_mac.hexdigest()
if not hmac.compare_digest(want_mac, got_mac): if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
raise SynapseError( raise SynapseError(403, "HMAC incorrect")
403, "HMAC incorrect",
)
# Reuse the parts of RegisterRestServlet to reduce code duplication # Reuse the parts of RegisterRestServlet to reduce code duplication
from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet
register = RegisterRestServlet(self.hs) register = RegisterRestServlet(self.hs)
(user_id, _) = yield register.registration_handler.register( (user_id, _) = yield register.registration_handler.register(
localpart=username.lower(), password=password, admin=bool(admin), localpart=body['username'].lower(),
password=body["password"],
admin=bool(admin),
generate_token=False, generate_token=False,
) )

View file

@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
kind = "user" kind = b"user"
if "kind" in request.args: if b"kind" in request.args:
kind = request.args["kind"][0] kind = request.args[b"kind"][0]
if kind == "guest": if kind == b"guest":
ret = yield self._do_guest_registration(body) ret = yield self._do_guest_registration(body)
defer.returnValue(ret) defer.returnValue(ret)
return return
elif kind != "user": elif kind != b"user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,) "Do not understand membership kind: %s" % (kind,)
) )
@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
assert_params_in_dict(params, ["password"]) assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None) guest_access_token = params.get("guest_access_token", None)
new_password = params.get("password", None)
if desired_username is not None: if desired_username is not None:
desired_username = desired_username.lower() desired_username = desired_username.lower()

View file

@ -177,7 +177,7 @@ class MediaStorage(object):
if res: if res:
with res: with res:
consumer = BackgroundFileConsumer( consumer = BackgroundFileConsumer(
open(local_path, "w"), self.hs.get_reactor()) open(local_path, "wb"), self.hs.get_reactor())
yield res.write_to_consumer(consumer) yield res.write_to_consumer(consumer)
yield consumer.wait() yield consumer.wait()
defer.returnValue(local_path) defer.returnValue(local_path)

View file

@ -577,7 +577,7 @@ def _make_state_cache_entry(
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
return sorted(events, key=key_func) return sorted(events, key=key_func)

View file

@ -67,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
def encode_json(json_object): def encode_json(json_object):
return frozendict_json_encoder.encode(json_object) """
Encode a Python object as JSON and return it in a Unicode string.
"""
out = frozendict_json_encoder.encode(json_object)
if isinstance(out, bytes):
out = out.decode('utf8')
return out
class _EventPeristenceQueue(object): class _EventPeristenceQueue(object):
@ -1058,7 +1064,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
metadata_json = encode_json( metadata_json = encode_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict()
).decode("UTF-8") )
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
@ -1172,8 +1178,8 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"room_id": event.room_id, "room_id": event.room_id,
"internal_metadata": encode_json( "internal_metadata": encode_json(
event.internal_metadata.get_dict() event.internal_metadata.get_dict()
).decode("UTF-8"), ),
"json": encode_json(event_dict(event)).decode("UTF-8"), "json": encode_json(event_dict(event)),
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
], ],

View file

@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
txn (cursor): txn (cursor):
event_id (str): Id for the Event. event_id (str): Id for the Event.
Returns: Returns:
A dict of algorithm -> hash. A dict[unicode, bytes] of algorithm -> hash.
""" """
query = ( query = (
"SELECT algorithm, hash" "SELECT algorithm, hash"

View file

@ -137,7 +137,7 @@ class DomainSpecificString(
@classmethod @classmethod
def from_string(cls, s): def from_string(cls, s):
"""Parse the string given by 's' into a structure object.""" """Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0] != cls.SIGIL: if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % ( raise SynapseError(400, "Expected %s string to start with '%s'" % (
cls.__name__, cls.SIGIL, cls.__name__, cls.SIGIL,
)) ))

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import string_types from six import binary_type, text_type
from canonicaljson import json from canonicaljson import json
from frozendict import frozendict from frozendict import frozendict
@ -26,7 +26,7 @@ def freeze(o):
if isinstance(o, frozendict): if isinstance(o, frozendict):
return o return o
if isinstance(o, string_types): if isinstance(o, (binary_type, text_type)):
return o return o
try: try:
@ -41,7 +41,7 @@ def unfreeze(o):
if isinstance(o, (dict, frozendict)): if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()}) return dict({k: unfreeze(v) for k, v in o.items()})
if isinstance(o, string_types): if isinstance(o, (binary_type, text_type)):
return o return o
try: try:

View file

@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
self.test_user = "@foo:bar" self.test_user = "@foo:bar"
self.test_token = "_test_token_" self.test_token = b"_test_token_"
# this is overridden for the appservice tests # this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10" request.getClientIP.return_value = "192.168.10.10"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42" request.getClientIP.return_value = "131.111.8.42"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None, ip_range_whitelist=None,
@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id) self.assertEquals(
requester.user.to_string(),
masquerading_user_id.decode('utf8')
)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user,
ip_range_whitelist=None, ip_range_whitelist=None,
@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
# check the token works # check the token works
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [token] request.args[b"access_token"] = [token.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertEqual(UserID.from_string(USER_ID), requester.user)
@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
# the token should *not* work now # the token should *not* work now
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [guest_tok] request.args[b"access_token"] = [guest_tok.encode('ascii')]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:

View file

@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
self.prefix = prefix self.prefix = prefix
def trigger_get(self, path): def trigger_get(self, path):
return self.trigger("GET", path, None) return self.trigger(b"GET", path, None)
@patch('twisted.web.http.Request') @patch('twisted.web.http.Request')
@defer.inlineCallbacks @defer.inlineCallbacks
@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
headers = {} headers = {}
if federation_auth: if federation_auth:
headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it # return the right path if the event requires it
@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
except Exception: except Exception:
pass pass
if isinstance(path, bytes):
path = path.decode('utf8')
for (method, pattern, func) in self.callbacks: for (method, pattern, func) in self.callbacks:
if http_method != method: if http_method != method:
continue continue
@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
if matcher: if matcher:
try: try:
args = [ args = [
urlparse.unquote(u).decode("UTF-8") urlparse.unquote(u)
for u in matcher.groups() for u in matcher.groups()
] ]