0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 06:43:46 +01:00

Convert auth handler to async/await (#7261)

This commit is contained in:
Patrick Cloke 2020-04-15 12:40:18 -04:00 committed by GitHub
parent 17a2433b0d
commit eed7c5b89e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 224 additions and 170 deletions

1
changelog.d/7261.misc Normal file
View file

@ -0,0 +1 @@
Convert auth handler to async/await.

View file

@ -18,14 +18,12 @@ import logging
import time import time
import unicodedata import unicodedata
import urllib.parse import urllib.parse
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import attr import attr
import bcrypt # type: ignore[import] import bcrypt # type: ignore[import]
import pymacaroons import pymacaroons
from twisted.internet import defer
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
@ -170,15 +168,14 @@ class AuthHandler(BaseHandler):
# cast to tuple for use with str.startswith # cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@defer.inlineCallbacks async def validate_user_via_ui_auth(
def validate_user_via_ui_auth(
self, self,
requester: Requester, requester: Requester,
request: SynapseRequest, request: SynapseRequest,
request_body: Dict[str, Any], request_body: Dict[str, Any],
clientip: str, clientip: str,
description: str, description: str,
): ) -> dict:
""" """
Checks that the user is who they claim to be, via a UI auth. Checks that the user is who they claim to be, via a UI auth.
@ -199,7 +196,7 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account. describes the operation happening on their account.
Returns: Returns:
defer.Deferred[dict]: the parameters for this request (which may The parameters for this request (which may
have been given only in a previous call). have been given only in a previous call).
Raises: Raises:
@ -229,7 +226,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_ui_auth_types] flows = [[login_type] for login_type in self._supported_ui_auth_types]
try: try:
result, params, _ = yield self.check_auth( result, params, _ = await self.check_auth(
flows, request, request_body, clientip, description flows, request, request_body, clientip, description
) )
except LoginError: except LoginError:
@ -268,15 +265,14 @@ class AuthHandler(BaseHandler):
""" """
return self.checkers.keys() return self.checkers.keys()
@defer.inlineCallbacks async def check_auth(
def check_auth(
self, self,
flows: List[List[str]], flows: List[List[str]],
request: SynapseRequest, request: SynapseRequest,
clientdict: Dict[str, Any], clientdict: Dict[str, Any],
clientip: str, clientip: str,
description: str, description: str,
): ) -> Tuple[dict, dict, str]:
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow. protocol and handles the User-Interactive Auth flow.
@ -306,8 +302,7 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account. describes the operation happening on their account.
Returns: Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of A tuple of (creds, params, session_id).
(creds, params, session_id).
'creds' contains the authenticated credentials of each stage. 'creds' contains the authenticated credentials of each stage.
@ -380,7 +375,7 @@ class AuthHandler(BaseHandler):
if "type" in authdict: if "type" in authdict:
login_type = authdict["type"] # type: str login_type = authdict["type"] # type: str
try: try:
result = yield self._check_auth_dict(authdict, clientip) result = await self._check_auth_dict(authdict, clientip)
if result: if result:
creds[login_type] = result creds[login_type] = result
self._save_session(session) self._save_session(session)
@ -419,8 +414,9 @@ class AuthHandler(BaseHandler):
ret.update(errordict) ret.update(errordict)
raise InteractiveAuthIncompleteError(ret) raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks async def add_oob_auth(
def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): self, stagetype: str, authdict: Dict[str, Any], clientip: str
) -> bool:
""" """
Adds the result of out-of-band authentication into an existing auth Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth. session. Currently used for adding the result of fallback auth.
@ -435,7 +431,7 @@ class AuthHandler(BaseHandler):
sess["creds"] = {} sess["creds"] = {}
creds = sess["creds"] creds = sess["creds"]
result = yield self.checkers[stagetype].check_auth(authdict, clientip) result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result: if result:
creds[stagetype] = result creds[stagetype] = result
self._save_session(sess) self._save_session(sess)
@ -489,8 +485,9 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default) return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks async def _check_auth_dict(
def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): self, authdict: Dict[str, Any], clientip: str
) -> Union[Dict[str, Any], str]:
"""Attempt to validate the auth dict provided by a client """Attempt to validate the auth dict provided by a client
Args: Args:
@ -498,7 +495,7 @@ class AuthHandler(BaseHandler):
clientip: IP address of the client clientip: IP address of the client
Returns: Returns:
Deferred: result of the stage verification. Result of the stage verification.
Raises: Raises:
StoreError if there was a problem accessing the database StoreError if there was a problem accessing the database
@ -508,7 +505,7 @@ class AuthHandler(BaseHandler):
login_type = authdict["type"] login_type = authdict["type"]
checker = self.checkers.get(login_type) checker = self.checkers.get(login_type)
if checker is not None: if checker is not None:
res = yield checker.check_auth(authdict, clientip=clientip) res = await checker.check_auth(authdict, clientip=clientip)
return res return res
# build a v1-login-style dict out of the authdict and fall back to the # build a v1-login-style dict out of the authdict and fall back to the
@ -518,7 +515,7 @@ class AuthHandler(BaseHandler):
if user_id is None: if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict) (canonical_id, callback) = await self.validate_login(user_id, authdict)
return canonical_id return canonical_id
def _get_params_recaptcha(self) -> dict: def _get_params_recaptcha(self) -> dict:
@ -584,8 +581,7 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks async def get_access_token_for_user_id(
def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
): ):
""" """
@ -615,10 +611,10 @@ class AuthHandler(BaseHandler):
) )
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
yield self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id) access_token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.add_access_token_to_user( await self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms user_id, access_token, device_id, valid_until_ms
) )
@ -628,15 +624,14 @@ class AuthHandler(BaseHandler):
# device, so we double-check it here. # device, so we double-check it here.
if device_id is not None: if device_id is not None:
try: try:
yield self.store.get_device(user_id, device_id) await self.store.get_device(user_id, device_id)
except StoreError: except StoreError:
yield self.store.delete_access_token(access_token) await self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion") raise StoreError(400, "Login raced against device deletion")
return access_token return access_token
@defer.inlineCallbacks async def check_user_exists(self, user_id: str) -> Optional[str]:
def check_user_exists(self, user_id: str):
""" """
Checks to see if a user with the given id exists. Will check case Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches. insensitively, but return None if there are multiple inexact matches.
@ -645,25 +640,25 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id user_id: complete @user:id
Returns: Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or The canonical_user_id, or None if zero or multiple matches
multiple matches
""" """
res = yield self._find_user_id_and_pwd_hash(user_id) res = await self._find_user_id_and_pwd_hash(user_id)
if res is not None: if res is not None:
return res[0] return res[0]
return None return None
@defer.inlineCallbacks async def _find_user_id_and_pwd_hash(
def _find_user_id_and_pwd_hash(self, user_id: str): self, user_id: str
) -> Optional[Tuple[str, str]]:
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact insensitively, but will return None if there are multiple inexact
matches. matches.
Returns: Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)` A 2-tuple of `(canonical_user_id, password_hash)` or `None`
None: if there is not exactly one match if there is not exactly one match
""" """
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) user_infos = await self.store.get_users_by_id_case_insensitive(user_id)
result = None result = None
if not user_infos: if not user_infos:
@ -696,8 +691,9 @@ class AuthHandler(BaseHandler):
""" """
return self._supported_login_types return self._supported_login_types
@defer.inlineCallbacks async def validate_login(
def validate_login(self, username: str, login_submission: Dict[str, Any]): self, username: str, login_submission: Dict[str, Any]
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API """Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate Also used by the user-interactive auth flow to validate
@ -708,7 +704,7 @@ class AuthHandler(BaseHandler):
login_submission: the whole of the login submission login_submission: the whole of the login submission
(including 'type' and other relevant fields) (including 'type' and other relevant fields)
Returns: Returns:
Deferred[str, func]: canonical user id, and optional callback A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued to be called once the access token and device id are issued
Raises: Raises:
StoreError if there was a problem accessing the database StoreError if there was a problem accessing the database
@ -737,7 +733,7 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True known_login_type = True
is_valid = yield provider.check_password(qualified_user_id, password) is_valid = await provider.check_password(qualified_user_id, password)
if is_valid: if is_valid:
return qualified_user_id, None return qualified_user_id, None
@ -769,7 +765,7 @@ class AuthHandler(BaseHandler):
% (login_type, missing_fields), % (login_type, missing_fields),
) )
result = yield provider.check_auth(username, login_type, login_dict) result = await provider.check_auth(username, login_type, login_dict)
if result: if result:
if isinstance(result, str): if isinstance(result, str):
result = (result, None) result = (result, None)
@ -778,8 +774,8 @@ class AuthHandler(BaseHandler):
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True known_login_type = True
canonical_user_id = yield self._check_local_password( canonical_user_id = await self._check_local_password(
qualified_user_id, password qualified_user_id, password # type: ignore
) )
if canonical_user_id: if canonical_user_id:
@ -792,8 +788,9 @@ class AuthHandler(BaseHandler):
# login, it turns all LoginErrors into a 401 anyway. # login, it turns all LoginErrors into a 401 anyway.
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks async def check_password_provider_3pid(
def check_password_provider_3pid(self, medium: str, address: str, password: str): self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
"""Check if a password provider is able to validate a thirdparty login """Check if a password provider is able to validate a thirdparty login
Args: Args:
@ -802,9 +799,8 @@ class AuthHandler(BaseHandler):
password: The password of the user. password: The password of the user.
Returns: Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id, A tuple of `(user_id, callback)`. If authentication is successful,
callback)`. If authentication is successful, `user_id` is a `str` `user_id`is the authenticated, canonical user ID. `callback` is
containing the authenticated, canonical user ID. `callback` is
then either a function to be later run after the server has then either a function to be later run after the server has
completed login/registration, or `None`. If authentication was completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`. unsuccessful, `user_id` and `callback` are both `None`.
@ -816,7 +812,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of # success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run # (user_id, callback_func), where callback_func should be run
# after we've finished everything else # after we've finished everything else
result = yield provider.check_3pid_auth(medium, address, password) result = await provider.check_3pid_auth(medium, address, password)
if result: if result:
# Check if the return value is a str or a tuple # Check if the return value is a str or a tuple
if isinstance(result, str): if isinstance(result, str):
@ -826,8 +822,7 @@ class AuthHandler(BaseHandler):
return None, None return None, None
@defer.inlineCallbacks async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
def _check_local_password(self, user_id: str, password: str):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are user_id is checked case insensitively, but will return None if there are
@ -837,28 +832,26 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id user_id: complete @user:id
password: the provided password password: the provided password
Returns: Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if The canonical_user_id, or None if unknown user/bad password
unknown user/bad password
""" """
lookupres = yield self._find_user_id_and_pwd_hash(user_id) lookupres = await self._find_user_id_and_pwd_hash(user_id)
if not lookupres: if not lookupres:
return None return None
(user_id, password_hash) = lookupres (user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated # If the password hash is None, the account has likely been deactivated
if not password_hash: if not password_hash:
deactivated = yield self.store.get_user_deactivated_status(user_id) deactivated = await self.store.get_user_deactivated_status(user_id)
if deactivated: if deactivated:
raise UserDeactivatedError("This account has been deactivated") raise UserDeactivatedError("This account has been deactivated")
result = yield self.validate_hash(password, password_hash) result = await self.validate_hash(password, password_hash)
if not result: if not result:
logger.warning("Failed password login for user %s", user_id) logger.warning("Failed password login for user %s", user_id)
return None return None
return user_id return user_id
@defer.inlineCallbacks async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None user_id = None
try: try:
@ -868,26 +861,23 @@ class AuthHandler(BaseHandler):
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
yield self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
return user_id return user_id
@defer.inlineCallbacks async def delete_access_token(self, access_token: str):
def delete_access_token(self, access_token: str):
"""Invalidate a single access token """Invalidate a single access token
Args: Args:
access_token: access token to be deleted access_token: access token to be deleted
Returns:
Deferred
""" """
user_info = yield self.auth.get_user_by_access_token(access_token) user_info = await self.auth.get_user_by_access_token(access_token)
yield self.store.delete_access_token(access_token) await self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this # see if any of our auth providers want to know about this
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "on_logged_out"): if hasattr(provider, "on_logged_out"):
yield provider.on_logged_out( await provider.on_logged_out(
user_id=str(user_info["user"]), user_id=str(user_info["user"]),
device_id=user_info["device_id"], device_id=user_info["device_id"],
access_token=access_token, access_token=access_token,
@ -895,12 +885,11 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token # delete pushers associated with this access token
if user_info["token_id"] is not None: if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token( await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],) str(user_info["user"]), (user_info["token_id"],)
) )
@defer.inlineCallbacks async def delete_access_tokens_for_user(
def delete_access_tokens_for_user(
self, self,
user_id: str, user_id: str,
except_token_id: Optional[str] = None, except_token_id: Optional[str] = None,
@ -914,10 +903,8 @@ class AuthHandler(BaseHandler):
device_id: ID of device the tokens are associated with. device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
be deleted be deleted
Returns:
Deferred
""" """
tokens_and_devices = yield self.store.user_delete_access_tokens( tokens_and_devices = await self.store.user_delete_access_tokens(
user_id, except_token_id=except_token_id, device_id=device_id user_id, except_token_id=except_token_id, device_id=device_id
) )
@ -925,17 +912,18 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "on_logged_out"): if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices: for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out( await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token user_id=user_id, device_id=device_id, access_token=token
) )
# delete pushers associated with the access tokens # delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token( await self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices) user_id, (token_id for _, token_id, _ in tokens_and_devices)
) )
@defer.inlineCallbacks async def add_threepid(
def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): self, user_id: str, medium: str, address: str, validated_at: int
):
# check if medium has a valid value # check if medium has a valid value
if medium not in ["email", "msisdn"]: if medium not in ["email", "msisdn"]:
raise SynapseError( raise SynapseError(
@ -956,14 +944,13 @@ class AuthHandler(BaseHandler):
if medium == "email": if medium == "email":
address = address.lower() address = address.lower()
yield self.store.user_add_threepid( await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec() user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
) )
@defer.inlineCallbacks async def delete_threepid(
def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
): ) -> bool:
"""Attempts to unbind the 3pid on the identity servers and deletes it """Attempts to unbind the 3pid on the identity servers and deletes it
from the local database. from the local database.
@ -976,7 +963,7 @@ class AuthHandler(BaseHandler):
identity server specified when binding (if known). identity server specified when binding (if known).
Returns: Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the the identity server, False if identity server doesn't support the
unbind API. unbind API.
""" """
@ -986,11 +973,11 @@ class AuthHandler(BaseHandler):
address = address.lower() address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler identity_handler = self.hs.get_handlers().identity_handler
result = yield identity_handler.try_unbind_threepid( result = await identity_handler.try_unbind_threepid(
user_id, {"medium": medium, "address": address, "id_server": id_server} user_id, {"medium": medium, "address": address, "id_server": id_server}
) )
yield self.store.user_delete_threepid(user_id, medium, address) await self.store.user_delete_threepid(user_id, medium, address)
return result return result
def _save_session(self, session: Dict[str, Any]) -> None: def _save_session(self, session: Dict[str, Any]) -> None:
@ -1000,14 +987,14 @@ class AuthHandler(BaseHandler):
session["last_used"] = self.hs.get_clock().time_msec() session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session self.sessions[session["id"]] = session
def hash(self, password: str): async def hash(self, password: str) -> str:
"""Computes a secure hash of password. """Computes a secure hash of password.
Args: Args:
password: Password to hash. password: Password to hash.
Returns: Returns:
Deferred(unicode): Hashed password. Hashed password.
""" """
def _do_hash(): def _do_hash():
@ -1019,9 +1006,11 @@ class AuthHandler(BaseHandler):
bcrypt.gensalt(self.bcrypt_rounds), bcrypt.gensalt(self.bcrypt_rounds),
).decode("ascii") ).decode("ascii")
return defer_to_thread(self.hs.get_reactor(), _do_hash) return await defer_to_thread(self.hs.get_reactor(), _do_hash)
def validate_hash(self, password: str, stored_hash: bytes): async def validate_hash(
self, password: str, stored_hash: Union[bytes, str]
) -> bool:
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
Args: Args:
@ -1029,7 +1018,7 @@ class AuthHandler(BaseHandler):
stored_hash: Expected hash value. stored_hash: Expected hash value.
Returns: Returns:
Deferred(bool): Whether self.hash(password) == stored_hash. Whether self.hash(password) == stored_hash.
""" """
def _do_validate_hash(): def _do_validate_hash():
@ -1045,9 +1034,9 @@ class AuthHandler(BaseHandler):
if not isinstance(stored_hash, bytes): if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii") stored_hash = stored_hash.encode("ascii")
return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else: else:
return defer.succeed(False) return False
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
""" """

View file

@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler):
else: else:
raise raise
yield self._auth_handler.delete_access_tokens_for_user( yield defer.ensureDeferred(
user_id, device_id=device_id self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
) )
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not # Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path. # considered as part of a critical path.
for device_id in device_ids: for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user( yield defer.ensureDeferred(
user_id, device_id=device_id self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
) )
yield self.store.delete_e2e_keys_by_device( yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id user_id=user_id, device_id=device_id

View file

@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid) yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None password_hash = None
if password: if password:
password_hash = yield self._auth_handler.hash(password) password_hash = yield defer.ensureDeferred(
self._auth_handler.hash(password)
)
if localpart is not None: if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token) yield self.check_username(localpart, guest_access_token=guest_access_token)
@ -540,8 +542,10 @@ class RegistrationHandler(BaseHandler):
user_id, ["guest = true"] user_id, ["guest = true"]
) )
else: else:
access_token = yield self._auth_handler.get_access_token_for_user_id( access_token = yield defer.ensureDeferred(
user_id, device_id=device_id, valid_until_ms=valid_until_ms self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
) )
return (device_id, access_token) return (device_id, access_token)
@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler):
logger.info("Can't add incomplete 3pid") logger.info("Can't add incomplete 3pid")
return return
yield self._auth_handler.add_threepid( yield defer.ensureDeferred(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"] self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
) )
# And we add an email pusher for them by default, but only # And we add an email pusher for them by default, but only
@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler):
return None return None
raise raise
yield self._auth_handler.add_threepid( yield defer.ensureDeferred(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"] self._auth_handler.add_threepid(
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
) )

View file

@ -15,8 +15,6 @@
import logging import logging
from typing import Optional from typing import Optional
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester from synapse.types import Requester
@ -34,8 +32,7 @@ class SetPasswordHandler(BaseHandler):
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler() self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks async def set_password(
def set_password(
self, self,
user_id: str, user_id: str,
new_password: str, new_password: str,
@ -46,10 +43,10 @@ class SetPasswordHandler(BaseHandler):
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
self._password_policy_handler.validate_password(new_password) self._password_policy_handler.validate_password(new_password)
password_hash = yield self._auth_handler.hash(new_password) password_hash = await self._auth_handler.hash(new_password)
try: try:
yield self.store.user_set_password_hash(user_id, password_hash) await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
@ -61,12 +58,12 @@ class SetPasswordHandler(BaseHandler):
except_access_token_id = requester.access_token_id if requester else None except_access_token_id = requester.access_token_id if requester else None
# First delete all of their other devices. # First delete all of their other devices.
yield self._device_handler.delete_all_devices_for_user( await self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id user_id, except_device_id=except_device_id
) )
# and now delete any access tokens which weren't associated with # and now delete any access tokens which weren't associated with
# devices (or were associated with this device). # devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user( await self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id user_id, except_token_id=except_access_token_id
) )

View file

@ -86,7 +86,7 @@ class ModuleApi(object):
Deferred[str|None]: Canonical (case-corrected) user_id, or None Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered. if the user is not registered.
""" """
return self._auth_handler.check_user_exists(user_id) return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart, displayname=None, emails=[]): def register(self, localpart, displayname=None, emails=[]):
@ -196,7 +196,9 @@ class ModuleApi(object):
yield self._hs.get_device_handler().delete_device(user_id, device_id) yield self._hs.get_device_handler().delete_device(user_id, device_id)
else: else:
# no associated device. Just delete the access token. # no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token) yield defer.ensureDeferred(
self._auth_handler.delete_access_token(access_token)
)
def run_db_interaction(self, desc, func, *args, **kwargs): def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection """Run a function with a database connection

View file

@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "127.0.0.1" request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "192.168.10.10" request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self): def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals( self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8") requester.user.to_string(), masquerading_user_id.decode("utf8")
) )
@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
user = user_info["user"] user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true") macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize() serialized = macaroon.serialize()
user_info = yield self.auth.get_user_by_access_token(serialized) user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
user = user_info["user"] user = user_info["user"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self): def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org" USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = Mock() self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
self.store.get_device = Mock(return_value=defer.succeed(None))
token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( token = yield defer.ensureDeferred(
USER_ID, "DEVICE", valid_until_ms=None self.hs.handlers.auth_handler.get_access_token_for_user_id(
USER_ID, "DEVICE", valid_until_ms=None
)
) )
self.store.add_access_token_to_user.assert_called_with( self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None USER_ID, token, "DEVICE", None
@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")] request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield defer.ensureDeferred(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest) self.assertFalse(requester.is_guest)
@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm: with self.assertRaises(InvalidClientCredentialsError) as cm:
yield self.auth.get_user_by_req(request, allow_guest=True) yield defer.ensureDeferred(
self.auth.get_user_by_req(request, allow_guest=True)
)
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg) self.assertEqual("Guest access token used for regular user", cm.exception.msg)
@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1 small_number_of_users = 1
# Ensure no error thrown # Ensure no error thrown
yield self.auth.check_auth_blocking() yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase):
) )
with self.assertRaises(ResourceLimitError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking() yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403) self.assertEquals(e.exception.code, 403)
@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users) return_value=defer.succeed(small_number_of_users)
) )
yield self.auth.check_auth_blocking() yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks @defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self): def test_blocking_mau__depending_on_user_type(self):
@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed # Support users allowed
yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed # Bots not allowed
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(user_type=UserTypes.BOT) yield defer.ensureDeferred(
self.auth.check_auth_blocking(user_type=UserTypes.BOT)
)
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed # Real users not allowed
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking() yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks @defer.inlineCallbacks
def test_reserved_threepid(self): def test_reserved_threepid(self):
@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.hs.config.mau_limits_reserved_threepids = [threepid] self.hs.config.mau_limits_reserved_threepids = [threepid]
yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking() yield defer.ensureDeferred(self.auth.check_auth_blocking())
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(threepid=unknown_threepid) yield defer.ensureDeferred(
self.auth.check_auth_blocking(threepid=unknown_threepid)
)
yield self.auth.check_auth_blocking(threepid=threepid) yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_hs_disabled(self): def test_hs_disabled(self):
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking() yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403) self.assertEquals(e.exception.code, 403)
@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking() yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403) self.assertEquals(e.exception.code, 403)
@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase):
user = "@user:server" user = "@user:server"
self.hs.config.server_notices_mxid = user self.hs.config.server_notices_mxid = user
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
yield self.auth.check_auth_blocking(user) yield defer.ensureDeferred(self.auth.check_auth_blocking(user))

View file

@ -82,16 +82,16 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( user_id = yield defer.ensureDeferred(
token self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
self.assertEqual("a_user", user_id) self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
self.hs.clock.now = 6000 self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield defer.ensureDeferred(
token self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -99,8 +99,10 @@ class AuthTestCase(unittest.TestCase):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( user_id = yield defer.ensureDeferred(
macaroon.serialize() self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
) )
self.assertEqual("a_user", user_id) self.assertEqual("a_user", user_id)
@ -109,20 +111,26 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user") macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield defer.ensureDeferred(
macaroon.serialize() self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_disabled(self): def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception # Ensure does not throw exception
yield self.auth_handler.get_access_token_for_user_id( yield defer.ensureDeferred(
"user_a", device_id=None, valid_until_ms=None self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
) )
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield defer.ensureDeferred(
self._get_macaroon().serialize() self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -133,16 +141,20 @@ class AuthTestCase(unittest.TestCase):
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id( yield defer.ensureDeferred(
"user_a", device_id=None, valid_until_ms=None self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield defer.ensureDeferred(
self._get_macaroon().serialize() self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -154,16 +166,20 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id( yield defer.ensureDeferred(
"user_a", device_id=None, valid_until_ms=None self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield defer.ensureDeferred(
self._get_macaroon().serialize() self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
) )
# If in monthly active cohort # If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.hs.get_datastore().user_last_seen_monthly_active = Mock(
@ -172,8 +188,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
yield self.auth_handler.get_access_token_for_user_id( yield defer.ensureDeferred(
"user_a", device_id=None, valid_until_ms=None self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
) )
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec()) return_value=defer.succeed(self.hs.get_clock().time_msec())
@ -181,8 +199,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield defer.ensureDeferred(
self._get_macaroon().serialize() self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -193,15 +213,19 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.small_number_of_users)
) )
# Ensure does not raise exception # Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id( yield defer.ensureDeferred(
"user_a", device_id=None, valid_until_ms=None self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.small_number_of_users)
) )
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield defer.ensureDeferred(
self._get_macaroon().serialize() self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
) )
def _get_macaroon(self): def _get_macaroon(self):

View file

@ -294,7 +294,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
create_profile_with_displayname=user.localpart, create_profile_with_displayname=user.localpart,
) )
else: else:
yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) yield defer.ensureDeferred(
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
)
yield self.store.add_access_token_to_user( yield self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None user_id=user_id, token=token, device_id=None, valid_until_ms=None

View file

@ -332,10 +332,15 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it # Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one # because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg) # beforehand and pass it in to the HS's constructor (chicken / egg)
hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() async def hash(p):
hs.get_auth_handler().validate_hash = ( return hashlib.md5(p.encode("utf8")).hexdigest()
lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
) hs.get_auth_handler().hash = hash
async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None) fed = kargs.get("resource_for_federation", None)
if fed: if fed: