forked from MirrorHub/synapse
Python 3: Convert some unicode/bytes uses (#3569)
This commit is contained in:
parent
c4842e16cb
commit
da7785147d
17 changed files with 122 additions and 67 deletions
1
changelog.d/3569.bugfix
Normal file
1
changelog.d/3569.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password.
|
|
@ -252,10 +252,10 @@ class Auth(object):
|
||||||
if ip_address not in app_service.ip_range_whitelist:
|
if ip_address not in app_service.ip_range_whitelist:
|
||||||
defer.returnValue((None, None))
|
defer.returnValue((None, None))
|
||||||
|
|
||||||
if "user_id" not in request.args:
|
if b"user_id" not in request.args:
|
||||||
defer.returnValue((app_service.sender, app_service))
|
defer.returnValue((app_service.sender, app_service))
|
||||||
|
|
||||||
user_id = request.args["user_id"][0]
|
user_id = request.args[b"user_id"][0].decode('utf8')
|
||||||
if app_service.sender == user_id:
|
if app_service.sender == user_id:
|
||||||
defer.returnValue((app_service.sender, app_service))
|
defer.returnValue((app_service.sender, app_service))
|
||||||
|
|
||||||
|
|
|
@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
|
||||||
param_dict = dict(kv.split("=") for kv in params)
|
param_dict = dict(kv.split("=") for kv in params)
|
||||||
|
|
||||||
def strip_quotes(value):
|
def strip_quotes(value):
|
||||||
if value.startswith(b"\""):
|
if value.startswith("\""):
|
||||||
return value[1:-1]
|
return value[1:-1]
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
@ -626,6 +627,7 @@ class AuthHandler(BaseHandler):
|
||||||
# special case to check for "password" for the check_password interface
|
# special case to check for "password" for the check_password interface
|
||||||
# for the auth providers
|
# for the auth providers
|
||||||
password = login_submission.get("password")
|
password = login_submission.get("password")
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD:
|
if login_type == LoginType.PASSWORD:
|
||||||
if not self._password_enabled:
|
if not self._password_enabled:
|
||||||
raise SynapseError(400, "Password login has been disabled.")
|
raise SynapseError(400, "Password login has been disabled.")
|
||||||
|
@ -707,9 +709,10 @@ class AuthHandler(BaseHandler):
|
||||||
multiple inexact matches.
|
multiple inexact matches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): complete @user:id
|
user_id (unicode): complete @user:id
|
||||||
|
password (unicode): the provided password
|
||||||
Returns:
|
Returns:
|
||||||
(str) the canonical_user_id, or None if unknown user / bad password
|
(unicode) the canonical_user_id, or None if unknown user / bad password
|
||||||
"""
|
"""
|
||||||
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
|
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
if not lookupres:
|
if not lookupres:
|
||||||
|
@ -849,14 +852,19 @@ class AuthHandler(BaseHandler):
|
||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (str): Password to hash.
|
password (unicode): Password to hash.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(str): Hashed password.
|
Deferred(unicode): Hashed password.
|
||||||
"""
|
"""
|
||||||
def _do_hash():
|
def _do_hash():
|
||||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
# Normalise the Unicode in the password
|
||||||
bcrypt.gensalt(self.bcrypt_rounds))
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
|
return bcrypt.hashpw(
|
||||||
|
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||||
|
bcrypt.gensalt(self.bcrypt_rounds),
|
||||||
|
).decode('ascii')
|
||||||
|
|
||||||
return make_deferred_yieldable(
|
return make_deferred_yieldable(
|
||||||
threads.deferToThreadPool(
|
threads.deferToThreadPool(
|
||||||
|
@ -868,16 +876,19 @@ class AuthHandler(BaseHandler):
|
||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (str): Password to hash.
|
password (unicode): Password to hash.
|
||||||
stored_hash (str): Expected hash value.
|
stored_hash (unicode): Expected hash value.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _do_validate_hash():
|
def _do_validate_hash():
|
||||||
|
# Normalise the Unicode in the password
|
||||||
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
return bcrypt.checkpw(
|
return bcrypt.checkpw(
|
||||||
password.encode('utf8') + self.hs.config.password_pepper,
|
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
|
||||||
stored_hash.encode('utf8')
|
stored_hash.encode('utf8')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
Args:
|
Args:
|
||||||
localpart : The local part of the user ID to register. If None,
|
localpart : The local part of the user ID to register. If None,
|
||||||
one will be generated.
|
one will be generated.
|
||||||
password (str) : The password to assign to this user so they can
|
password (unicode) : The password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
generate_token (bool): Whether a new access token should be
|
generate_token (bool): Whether a new access token should be
|
||||||
|
|
|
@ -13,12 +13,13 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import cgi
|
import cgi
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
|
||||||
|
|
||||||
from six.moves import http_client
|
from six import PY3
|
||||||
|
from six.moves import http_client, urllib
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||||
|
|
||||||
|
@ -264,6 +265,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def register_paths(self, method, path_patterns, callback):
|
def register_paths(self, method, path_patterns, callback):
|
||||||
|
method = method.encode("utf-8") # method is bytes on py3
|
||||||
for path_pattern in path_patterns:
|
for path_pattern in path_patterns:
|
||||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||||
self.path_regexs.setdefault(method, []).append(
|
self.path_regexs.setdefault(method, []).append(
|
||||||
|
@ -296,8 +298,19 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
# here. If it throws an exception, that is handled by the wrapper
|
# here. If it throws an exception, that is handled by the wrapper
|
||||||
# installed by @request_handler.
|
# installed by @request_handler.
|
||||||
|
|
||||||
|
def _unquote(s):
|
||||||
|
if PY3:
|
||||||
|
# On Python 3, unquote is unicode -> unicode
|
||||||
|
return urllib.parse.unquote(s)
|
||||||
|
else:
|
||||||
|
# On Python 2, unquote is bytes -> bytes We need to encode the
|
||||||
|
# URL again (as it was decoded by _get_handler_for request), as
|
||||||
|
# ASCII because it's a URL, and then decode it to get the UTF-8
|
||||||
|
# characters that were quoted.
|
||||||
|
return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
|
||||||
|
|
||||||
kwargs = intern_dict({
|
kwargs = intern_dict({
|
||||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
name: _unquote(value) if value else value
|
||||||
for name, value in group_dict.items()
|
for name, value in group_dict.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -313,9 +326,9 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
request (twisted.web.http.Request):
|
request (twisted.web.http.Request):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Callable, dict[str, str]]: callback method, and the dict
|
Tuple[Callable, dict[unicode, unicode]]: callback method, and the
|
||||||
mapping keys to path components as specified in the handler's
|
dict mapping keys to path components as specified in the
|
||||||
path match regexp.
|
handler's path match regexp.
|
||||||
|
|
||||||
The callback will normally be a method registered via
|
The callback will normally be a method registered via
|
||||||
register_paths, so will return (possibly via Deferred) either
|
register_paths, so will return (possibly via Deferred) either
|
||||||
|
@ -327,7 +340,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
# Loop through all the registered callbacks to check if the method
|
# Loop through all the registered callbacks to check if the method
|
||||||
# and path regex match
|
# and path regex match
|
||||||
for path_entry in self.path_regexs.get(request.method, []):
|
for path_entry in self.path_regexs.get(request.method, []):
|
||||||
m = path_entry.pattern.match(request.path)
|
m = path_entry.pattern.match(request.path.decode('ascii'))
|
||||||
if m:
|
if m:
|
||||||
# We found a match!
|
# We found a match!
|
||||||
return path_entry.callback, m.groupdict()
|
return path_entry.callback, m.groupdict()
|
||||||
|
@ -383,7 +396,7 @@ class RootRedirect(resource.Resource):
|
||||||
self.url = path
|
self.url = path
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
return redirectTo(self.url, request)
|
return redirectTo(self.url.encode('ascii'), request)
|
||||||
|
|
||||||
def getChild(self, name, request):
|
def getChild(self, name, request):
|
||||||
if len(name) == 0:
|
if len(name) == 0:
|
||||||
|
@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
|
||||||
return
|
return
|
||||||
|
|
||||||
if pretty_print:
|
if pretty_print:
|
||||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
|
||||||
|
).encode("utf-8")
|
||||||
else:
|
else:
|
||||||
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
if canonical_json or synapse.events.USE_FROZEN_DICTS:
|
||||||
|
# canonicaljson already encodes to bytes
|
||||||
json_bytes = encode_canonical_json(json_object)
|
json_bytes = encode_canonical_json(json_object)
|
||||||
else:
|
else:
|
||||||
json_bytes = json.dumps(json_object)
|
json_bytes = json.dumps(json_object).encode("utf-8")
|
||||||
|
|
||||||
return respond_with_json_bytes(
|
return respond_with_json_bytes(
|
||||||
request, code, json_bytes,
|
request, code, json_bytes,
|
||||||
|
|
|
@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
||||||
if not content_bytes and allow_empty_body:
|
if not content_bytes and allow_empty_body:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Decode to Unicode so that simplejson will return Unicode strings on
|
||||||
|
# Python 2
|
||||||
try:
|
try:
|
||||||
content = json.loads(content_bytes)
|
content_unicode = content_bytes.decode('utf8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logger.warn("Unable to decode UTF-8")
|
||||||
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = json.loads(content_unicode)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("Unable to parse JSON: %s", e)
|
logger.warn("Unable to parse JSON: %s", e)
|
||||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||||
|
|
|
@ -18,6 +18,7 @@ import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from six import text_type
|
||||||
from six.moves import http_client
|
from six.moves import http_client
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (not isinstance(body['username'], str) or len(body['username']) > 512):
|
if (
|
||||||
|
not isinstance(body['username'], text_type)
|
||||||
|
or len(body['username']) > 512
|
||||||
|
):
|
||||||
raise SynapseError(400, "Invalid username")
|
raise SynapseError(400, "Invalid username")
|
||||||
|
|
||||||
username = body["username"].encode("utf-8")
|
username = body["username"].encode("utf-8")
|
||||||
|
@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
400, "password must be specified", errcode=Codes.BAD_JSON,
|
400, "password must be specified", errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (not isinstance(body['password'], str) or len(body['password']) > 512):
|
if (
|
||||||
|
not isinstance(body['password'], text_type)
|
||||||
|
or len(body['password']) > 512
|
||||||
|
):
|
||||||
raise SynapseError(400, "Invalid password")
|
raise SynapseError(400, "Invalid password")
|
||||||
|
|
||||||
password = body["password"].encode("utf-8")
|
password = body["password"].encode("utf-8")
|
||||||
|
@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
|
||||||
want_mac.update(b"admin" if admin else b"notadmin")
|
want_mac.update(b"admin" if admin else b"notadmin")
|
||||||
want_mac = want_mac.hexdigest()
|
want_mac = want_mac.hexdigest()
|
||||||
|
|
||||||
if not hmac.compare_digest(want_mac, got_mac):
|
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
|
||||||
raise SynapseError(
|
raise SynapseError(403, "HMAC incorrect")
|
||||||
403, "HMAC incorrect",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
||||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||||
|
|
||||||
register = RegisterRestServlet(self.hs)
|
register = RegisterRestServlet(self.hs)
|
||||||
|
|
||||||
(user_id, _) = yield register.registration_handler.register(
|
(user_id, _) = yield register.registration_handler.register(
|
||||||
localpart=username.lower(), password=password, admin=bool(admin),
|
localpart=body['username'].lower(),
|
||||||
|
password=body["password"],
|
||||||
|
admin=bool(admin),
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
kind = "user"
|
kind = b"user"
|
||||||
if "kind" in request.args:
|
if b"kind" in request.args:
|
||||||
kind = request.args["kind"][0]
|
kind = request.args[b"kind"][0]
|
||||||
|
|
||||||
if kind == "guest":
|
if kind == b"guest":
|
||||||
ret = yield self._do_guest_registration(body)
|
ret = yield self._do_guest_registration(body)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
return
|
return
|
||||||
elif kind != "user":
|
elif kind != b"user":
|
||||||
raise UnrecognizedRequestError(
|
raise UnrecognizedRequestError(
|
||||||
"Do not understand membership kind: %s" % (kind,)
|
"Do not understand membership kind: %s" % (kind,)
|
||||||
)
|
)
|
||||||
|
@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
|
||||||
assert_params_in_dict(params, ["password"])
|
assert_params_in_dict(params, ["password"])
|
||||||
|
|
||||||
desired_username = params.get("username", None)
|
desired_username = params.get("username", None)
|
||||||
new_password = params.get("password", None)
|
|
||||||
guest_access_token = params.get("guest_access_token", None)
|
guest_access_token = params.get("guest_access_token", None)
|
||||||
|
new_password = params.get("password", None)
|
||||||
|
|
||||||
if desired_username is not None:
|
if desired_username is not None:
|
||||||
desired_username = desired_username.lower()
|
desired_username = desired_username.lower()
|
||||||
|
|
|
@ -177,7 +177,7 @@ class MediaStorage(object):
|
||||||
if res:
|
if res:
|
||||||
with res:
|
with res:
|
||||||
consumer = BackgroundFileConsumer(
|
consumer = BackgroundFileConsumer(
|
||||||
open(local_path, "w"), self.hs.get_reactor())
|
open(local_path, "wb"), self.hs.get_reactor())
|
||||||
yield res.write_to_consumer(consumer)
|
yield res.write_to_consumer(consumer)
|
||||||
yield consumer.wait()
|
yield consumer.wait()
|
||||||
defer.returnValue(local_path)
|
defer.returnValue(local_path)
|
||||||
|
|
|
@ -577,7 +577,7 @@ def _make_state_cache_entry(
|
||||||
|
|
||||||
def _ordered_events(events):
|
def _ordered_events(events):
|
||||||
def key_func(e):
|
def key_func(e):
|
||||||
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
|
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
|
||||||
|
|
||||||
return sorted(events, key=key_func)
|
return sorted(events, key=key_func)
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
|
||||||
|
|
||||||
|
|
||||||
def encode_json(json_object):
|
def encode_json(json_object):
|
||||||
return frozendict_json_encoder.encode(json_object)
|
"""
|
||||||
|
Encode a Python object as JSON and return it in a Unicode string.
|
||||||
|
"""
|
||||||
|
out = frozendict_json_encoder.encode(json_object)
|
||||||
|
if isinstance(out, bytes):
|
||||||
|
out = out.decode('utf8')
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class _EventPeristenceQueue(object):
|
class _EventPeristenceQueue(object):
|
||||||
|
@ -1058,7 +1064,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
||||||
|
|
||||||
metadata_json = encode_json(
|
metadata_json = encode_json(
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
).decode("UTF-8")
|
)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"UPDATE event_json SET internal_metadata = ?"
|
"UPDATE event_json SET internal_metadata = ?"
|
||||||
|
@ -1172,8 +1178,8 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"internal_metadata": encode_json(
|
"internal_metadata": encode_json(
|
||||||
event.internal_metadata.get_dict()
|
event.internal_metadata.get_dict()
|
||||||
).decode("UTF-8"),
|
),
|
||||||
"json": encode_json(event_dict(event)).decode("UTF-8"),
|
"json": encode_json(event_dict(event)),
|
||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
],
|
],
|
||||||
|
|
|
@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||||
txn (cursor):
|
txn (cursor):
|
||||||
event_id (str): Id for the Event.
|
event_id (str): Id for the Event.
|
||||||
Returns:
|
Returns:
|
||||||
A dict of algorithm -> hash.
|
A dict[unicode, bytes] of algorithm -> hash.
|
||||||
"""
|
"""
|
||||||
query = (
|
query = (
|
||||||
"SELECT algorithm, hash"
|
"SELECT algorithm, hash"
|
||||||
|
|
|
@ -137,7 +137,7 @@ class DomainSpecificString(
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, s):
|
def from_string(cls, s):
|
||||||
"""Parse the string given by 's' into a structure object."""
|
"""Parse the string given by 's' into a structure object."""
|
||||||
if len(s) < 1 or s[0] != cls.SIGIL:
|
if len(s) < 1 or s[0:1] != cls.SIGIL:
|
||||||
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
||||||
cls.__name__, cls.SIGIL,
|
cls.__name__, cls.SIGIL,
|
||||||
))
|
))
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from six import string_types
|
from six import binary_type, text_type
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
@ -26,7 +26,7 @@ def freeze(o):
|
||||||
if isinstance(o, frozendict):
|
if isinstance(o, frozendict):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
if isinstance(o, string_types):
|
if isinstance(o, (binary_type, text_type)):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -41,7 +41,7 @@ def unfreeze(o):
|
||||||
if isinstance(o, (dict, frozendict)):
|
if isinstance(o, (dict, frozendict)):
|
||||||
return dict({k: unfreeze(v) for k, v in o.items()})
|
return dict({k: unfreeze(v) for k, v in o.items()})
|
||||||
|
|
||||||
if isinstance(o, string_types):
|
if isinstance(o, (binary_type, text_type)):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.auth = Auth(self.hs)
|
self.auth = Auth(self.hs)
|
||||||
|
|
||||||
self.test_user = "@foo:bar"
|
self.test_user = "@foo:bar"
|
||||||
self.test_token = "_test_token_"
|
self.test_token = b"_test_token_"
|
||||||
|
|
||||||
# this is overridden for the appservice tests
|
# this is overridden for the appservice tests
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
|
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "192.168.10.10"
|
request.getClientIP.return_value = "192.168.10.10"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "131.111.8.42"
|
request.getClientIP.return_value = "131.111.8.42"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
||||||
masquerading_user_id = "@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
|
@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(requester.user.to_string(), masquerading_user_id)
|
self.assertEquals(
|
||||||
|
requester.user.to_string(),
|
||||||
|
masquerading_user_id.decode('utf8')
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||||
masquerading_user_id = "@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user,
|
token="foobar", url="a_url", sender=self.test_user,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist=None,
|
||||||
|
@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = self.auth.get_user_by_req(request)
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# check the token works
|
# check the token works
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [token]
|
request.args[b"access_token"] = [token.encode('ascii')]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
self.assertEqual(UserID.from_string(USER_ID), requester.user)
|
||||||
|
@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# the token should *not* work now
|
# the token should *not* work now
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [guest_tok]
|
request.args[b"access_token"] = [guest_tok.encode('ascii')]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
|
||||||
with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
|
|
|
@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
def trigger_get(self, path):
|
def trigger_get(self, path):
|
||||||
return self.trigger("GET", path, None)
|
return self.trigger(b"GET", path, None)
|
||||||
|
|
||||||
@patch('twisted.web.http.Request')
|
@patch('twisted.web.http.Request')
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
if federation_auth:
|
if federation_auth:
|
||||||
headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
|
headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
|
||||||
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
||||||
|
|
||||||
# return the right path if the event requires it
|
# return the right path if the event requires it
|
||||||
|
@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if isinstance(path, bytes):
|
||||||
|
path = path.decode('utf8')
|
||||||
|
|
||||||
for (method, pattern, func) in self.callbacks:
|
for (method, pattern, func) in self.callbacks:
|
||||||
if http_method != method:
|
if http_method != method:
|
||||||
continue
|
continue
|
||||||
|
@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
|
||||||
if matcher:
|
if matcher:
|
||||||
try:
|
try:
|
||||||
args = [
|
args = [
|
||||||
urlparse.unquote(u).decode("UTF-8")
|
urlparse.unquote(u)
|
||||||
for u in matcher.groups()
|
for u in matcher.groups()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue