Convert auth handler to async/await (#7261)

This commit is contained in:
Patrick Cloke 2020-04-15 12:40:18 -04:00 committed by GitHub
parent 17a2433b0d
commit eed7c5b89e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 224 additions and 170 deletions

1
changelog.d/7261.misc Normal file
View file

@ -0,0 +1 @@
Convert auth handler to async/await.

View file

@ -18,14 +18,12 @@ import logging
import time
import unicodedata
import urllib.parse
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import attr
import bcrypt # type: ignore[import]
import pymacaroons
from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import (
@ -170,15 +168,14 @@ class AuthHandler(BaseHandler):
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@defer.inlineCallbacks
def validate_user_via_ui_auth(
async def validate_user_via_ui_auth(
self,
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
clientip: str,
description: str,
):
) -> dict:
"""
Checks that the user is who they claim to be, via a UI auth.
@ -199,7 +196,7 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
The parameters for this request (which may
have been given only in a previous call).
Raises:
@ -229,7 +226,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_ui_auth_types]
try:
result, params, _ = yield self.check_auth(
result, params, _ = await self.check_auth(
flows, request, request_body, clientip, description
)
except LoginError:
@ -268,15 +265,14 @@ class AuthHandler(BaseHandler):
"""
return self.checkers.keys()
@defer.inlineCallbacks
def check_auth(
async def check_auth(
self,
flows: List[List[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
clientip: str,
description: str,
):
) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@ -306,8 +302,7 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of
(creds, params, session_id).
A tuple of (creds, params, session_id).
'creds' contains the authenticated credentials of each stage.
@ -380,7 +375,7 @@ class AuthHandler(BaseHandler):
if "type" in authdict:
login_type = authdict["type"] # type: str
try:
result = yield self._check_auth_dict(authdict, clientip)
result = await self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
@ -419,8 +414,9 @@ class AuthHandler(BaseHandler):
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
) -> bool:
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
@ -435,7 +431,7 @@ class AuthHandler(BaseHandler):
sess["creds"] = {}
creds = sess["creds"]
result = yield self.checkers[stagetype].check_auth(authdict, clientip)
result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
@ -489,8 +485,9 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
) -> Union[Dict[str, Any], str]:
"""Attempt to validate the auth dict provided by a client
Args:
@ -498,7 +495,7 @@ class AuthHandler(BaseHandler):
clientip: IP address of the client
Returns:
Deferred: result of the stage verification.
Result of the stage verification.
Raises:
StoreError if there was a problem accessing the database
@ -508,7 +505,7 @@ class AuthHandler(BaseHandler):
login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
res = yield checker.check_auth(authdict, clientip=clientip)
res = await checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
@ -518,7 +515,7 @@ class AuthHandler(BaseHandler):
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
(canonical_id, callback) = await self.validate_login(user_id, authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@ -584,8 +581,7 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
def get_access_token_for_user_id(
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
"""
@ -615,10 +611,10 @@ class AuthHandler(BaseHandler):
)
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
yield self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.add_access_token_to_user(
await self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms
)
@ -628,15 +624,14 @@ class AuthHandler(BaseHandler):
# device, so we double-check it here.
if device_id is not None:
try:
yield self.store.get_device(user_id, device_id)
await self.store.get_device(user_id, device_id)
except StoreError:
yield self.store.delete_access_token(access_token)
await self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
return access_token
@defer.inlineCallbacks
def check_user_exists(self, user_id: str):
async def check_user_exists(self, user_id: str) -> Optional[str]:
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
@ -645,25 +640,25 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id
Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches
The canonical_user_id, or None if zero or multiple matches
"""
res = yield self._find_user_id_and_pwd_hash(user_id)
res = await self._find_user_id_and_pwd_hash(user_id)
if res is not None:
return res[0]
return None
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id: str):
async def _find_user_id_and_pwd_hash(
self, user_id: str
) -> Optional[Tuple[str, str]]:
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
None: if there is not exactly one match
A 2-tuple of `(canonical_user_id, password_hash)` or `None`
if there is not exactly one match
"""
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
user_infos = await self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos:
@ -696,8 +691,9 @@ class AuthHandler(BaseHandler):
"""
return self._supported_login_types
@defer.inlineCallbacks
def validate_login(self, username: str, login_submission: Dict[str, Any]):
async def validate_login(
self, username: str, login_submission: Dict[str, Any]
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
@ -708,7 +704,7 @@ class AuthHandler(BaseHandler):
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
Deferred[str, func]: canonical user id, and optional callback
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
Raises:
StoreError if there was a problem accessing the database
@ -737,7 +733,7 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
is_valid = yield provider.check_password(qualified_user_id, password)
is_valid = await provider.check_password(qualified_user_id, password)
if is_valid:
return qualified_user_id, None
@ -769,7 +765,7 @@ class AuthHandler(BaseHandler):
% (login_type, missing_fields),
)
result = yield provider.check_auth(username, login_type, login_dict)
result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
@ -778,8 +774,8 @@ class AuthHandler(BaseHandler):
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
canonical_user_id = yield self._check_local_password(
qualified_user_id, password
canonical_user_id = await self._check_local_password(
qualified_user_id, password # type: ignore
)
if canonical_user_id:
@ -792,8 +788,9 @@ class AuthHandler(BaseHandler):
# login, it turns all LoginErrors into a 401 anyway.
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium: str, address: str, password: str):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@ -802,9 +799,8 @@ class AuthHandler(BaseHandler):
password: The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
callback)`. If authentication is successful, `user_id` is a `str`
containing the authenticated, canonical user ID. `callback` is
A tuple of `(user_id, callback)`. If authentication is successful,
`user_id`is the authenticated, canonical user ID. `callback` is
then either a function to be later run after the server has
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
@ -816,7 +812,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = yield provider.check_3pid_auth(medium, address, password)
result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
@ -826,8 +822,7 @@ class AuthHandler(BaseHandler):
return None, None
@defer.inlineCallbacks
def _check_local_password(self, user_id: str, password: str):
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
@ -837,28 +832,26 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id
password: the provided password
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
The canonical_user_id, or None if unknown user/bad password
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
lookupres = await self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
return None
(user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
if not password_hash:
deactivated = yield self.store.get_user_deactivated_status(user_id)
deactivated = await self.store.get_user_deactivated_status(user_id)
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
result = yield self.validate_hash(password, password_hash)
result = await self.validate_hash(password, password_hash)
if not result:
logger.warning("Failed password login for user %s", user_id)
return None
return user_id
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token: str):
async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
@ -868,26 +861,23 @@ class AuthHandler(BaseHandler):
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
yield self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(user_id)
return user_id
@defer.inlineCallbacks
def delete_access_token(self, access_token: str):
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
access_token: access token to be deleted
Returns:
Deferred
"""
user_info = yield self.auth.get_user_by_access_token(access_token)
yield self.store.delete_access_token(access_token)
user_info = await self.auth.get_user_by_access_token(access_token)
await self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
yield provider.on_logged_out(
await provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
@ -895,12 +885,11 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],)
)
@defer.inlineCallbacks
def delete_access_tokens_for_user(
async def delete_access_tokens_for_user(
self,
user_id: str,
except_token_id: Optional[str] = None,
@ -914,10 +903,8 @@ class AuthHandler(BaseHandler):
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
Deferred
"""
tokens_and_devices = yield self.store.user_delete_access_tokens(
tokens_and_devices = await self.store.user_delete_access_tokens(
user_id, except_token_id=except_token_id, device_id=device_id
)
@ -925,17 +912,18 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out(
await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
await self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
@defer.inlineCallbacks
def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
):
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@ -956,14 +944,13 @@ class AuthHandler(BaseHandler):
if medium == "email":
address = address.lower()
yield self.store.user_add_threepid(
await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
def delete_threepid(
async def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
):
) -> bool:
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
@ -976,7 +963,7 @@ class AuthHandler(BaseHandler):
identity server specified when binding (if known).
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
unbind API.
"""
@ -986,11 +973,11 @@ class AuthHandler(BaseHandler):
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
result = yield identity_handler.try_unbind_threepid(
result = await identity_handler.try_unbind_threepid(
user_id, {"medium": medium, "address": address, "id_server": id_server}
)
yield self.store.user_delete_threepid(user_id, medium, address)
await self.store.user_delete_threepid(user_id, medium, address)
return result
def _save_session(self, session: Dict[str, Any]) -> None:
@ -1000,14 +987,14 @@ class AuthHandler(BaseHandler):
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
def hash(self, password: str):
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
Args:
password: Password to hash.
Returns:
Deferred(unicode): Hashed password.
Hashed password.
"""
def _do_hash():
@ -1019,9 +1006,11 @@ class AuthHandler(BaseHandler):
bcrypt.gensalt(self.bcrypt_rounds),
).decode("ascii")
return defer_to_thread(self.hs.get_reactor(), _do_hash)
return await defer_to_thread(self.hs.get_reactor(), _do_hash)
def validate_hash(self, password: str, stored_hash: bytes):
async def validate_hash(
self, password: str, stored_hash: Union[bytes, str]
) -> bool:
"""Validates that self.hash(password) == stored_hash.
Args:
@ -1029,7 +1018,7 @@ class AuthHandler(BaseHandler):
stored_hash: Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
@ -1045,9 +1034,9 @@ class AuthHandler(BaseHandler):
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii")
return defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
return defer.succeed(False)
return False
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""

View file

@ -338,9 +338,11 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
yield self._auth_handler.delete_access_tokens_for_user(
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
@ -391,9 +393,11 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user(
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)

View file

@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self._auth_handler.hash(password)
password_hash = yield defer.ensureDeferred(
self._auth_handler.hash(password)
)
if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@ -540,9 +542,11 @@ class RegistrationHandler(BaseHandler):
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
access_token = yield defer.ensureDeferred(
self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
)
return (device_id, access_token)
@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler):
logger.info("Can't add incomplete 3pid")
return
yield self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
yield defer.ensureDeferred(
self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
)
# And we add an email pusher for them by default, but only
@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler):
return None
raise
yield self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
yield defer.ensureDeferred(
self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
)

View file

@ -15,8 +15,6 @@
import logging
from typing import Optional
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
@ -34,8 +32,7 @@ class SetPasswordHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
def set_password(
async def set_password(
self,
user_id: str,
new_password: str,
@ -46,10 +43,10 @@ class SetPasswordHandler(BaseHandler):
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
self._password_policy_handler.validate_password(new_password)
password_hash = yield self._auth_handler.hash(new_password)
password_hash = await self._auth_handler.hash(new_password)
try:
yield self.store.user_set_password_hash(user_id, password_hash)
await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
@ -61,12 +58,12 @@ class SetPasswordHandler(BaseHandler):
except_access_token_id = requester.access_token_id if requester else None
# First delete all of their other devices.
yield self._device_handler.delete_all_devices_for_user(
await self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
await self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id
)

View file

@ -86,7 +86,7 @@ class ModuleApi(object):
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return self._auth_handler.check_user_exists(user_id)
return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
@defer.inlineCallbacks
def register(self, localpart, displayname=None, emails=[]):
@ -196,7 +196,9 @@ class ModuleApi(object):
yield self._hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
yield defer.ensureDeferred(
self._auth_handler.delete_access_token(access_token)
)
def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection

View file

@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
user_info = yield self.auth.get_user_by_access_token(serialized)
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
@ -260,11 +264,14 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = Mock()
self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
self.store.get_device = Mock(return_value=defer.succeed(None))
token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
token = yield defer.ensureDeferred(
self.hs.handlers.auth_handler.get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
)
@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield defer.ensureDeferred(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)
@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm:
yield self.auth.get_user_by_req(request, allow_guest=True)
yield defer.ensureDeferred(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1
# Ensure no error thrown
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.hs.config.limit_usage_by_mau = True
@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.BOT)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_reserved_threepid(self):
@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.hs.config.mau_limits_reserved_threepids = [threepid]
yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(threepid=unknown_threepid)
yield defer.ensureDeferred(
self.auth.check_auth_blocking(threepid=unknown_threepid)
)
yield self.auth.check_auth_blocking(threepid=threepid)
yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase):
user = "@user:server"
self.hs.config.server_notices_mxid = user
self.hs.config.hs_disabled_message = "Reason for being disabled"
yield self.auth.check_auth_blocking(user)
yield defer.ensureDeferred(self.auth.check_auth_blocking(user))

View file

@ -82,16 +82,16 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
user_id = yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@defer.inlineCallbacks
@ -99,9 +99,11 @@ class AuthTestCase(unittest.TestCase):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
user_id = yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
)
self.assertEqual("a_user", user_id)
# add another "user_id" caveat, which might allow us to override the
@ -109,21 +111,27 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
)
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
yield self.auth_handler.get_access_token_for_user_id(
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
)
@defer.inlineCallbacks
def test_mau_limits_exceeded_large(self):
@ -133,17 +141,21 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id(
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
)
@defer.inlineCallbacks
def test_mau_limits_parity(self):
@ -154,17 +166,21 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id(
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
@ -172,18 +188,22 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.get_access_token_for_user_id(
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
)
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
@ -193,16 +213,20 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id(
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)

View file

@ -294,7 +294,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
create_profile_with_displayname=user.localpart,
)
else:
yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
yield defer.ensureDeferred(
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
)
yield self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None

View file

@ -332,10 +332,15 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().validate_hash = (
lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
)
async def hash(p):
return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash
async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
if fed: