forked from MirrorHub/synapse
Add type annotations and comments to auth handler (#7063)
This commit is contained in:
parent
bd5e555b0d
commit
77d0a4507b
3 changed files with 106 additions and 89 deletions
1
changelog.d/7063.misc
Normal file
1
changelog.d/7063.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations and comments to the auth handler.
|
|
@ -18,10 +18,10 @@ import logging
|
||||||
import time
|
import time
|
||||||
import unicodedata
|
import unicodedata
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any
|
from typing import Any, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import bcrypt
|
import bcrypt # type: ignore[import]
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import defer_to_thread
|
from synapse.logging.context import defer_to_thread
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.push.mailer import load_jinja2_templates
|
from synapse.push.mailer import load_jinja2_templates
|
||||||
from synapse.types import UserID
|
from synapse.types import Requester, UserID
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
super(AuthHandler, self).__init__(hs)
|
super(AuthHandler, self).__init__(hs)
|
||||||
|
|
||||||
self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
|
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
||||||
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
||||||
inst = auth_checker_class(hs)
|
inst = auth_checker_class(hs)
|
||||||
if inst.is_enabled():
|
if inst.is_enabled():
|
||||||
self.checkers[inst.AUTH_TYPE] = inst
|
self.checkers[inst.AUTH_TYPE] = inst # type: ignore
|
||||||
|
|
||||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||||
|
|
||||||
|
@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
|
||||||
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def validate_user_via_ui_auth(self, requester, request_body, clientip):
|
def validate_user_via_ui_auth(
|
||||||
|
self, requester: Requester, request_body: Dict[str, Any], clientip: str
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Checks that the user is who they claim to be, via a UI auth.
|
Checks that the user is who they claim to be, via a UI auth.
|
||||||
|
|
||||||
|
@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
|
||||||
that it isn't stolen by re-authenticating them.
|
that it isn't stolen by re-authenticating them.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester (Requester): The user, as given by the access token
|
requester: The user, as given by the access token
|
||||||
|
|
||||||
request_body (dict): The body of the request sent by the client
|
request_body: The body of the request sent by the client
|
||||||
|
|
||||||
clientip (str): The IP address of the client.
|
clientip: The IP address of the client.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[dict]: the parameters for this request (which may
|
defer.Deferred[dict]: the parameters for this request (which may
|
||||||
|
@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
|
||||||
return self.checkers.keys()
|
return self.checkers.keys()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(
|
||||||
|
self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Takes a dictionary sent by the client in the login / registration
|
Takes a dictionary sent by the client in the login / registration
|
||||||
protocol and handles the User-Interactive Auth flow.
|
protocol and handles the User-Interactive Auth flow.
|
||||||
|
@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
|
||||||
decorator.
|
decorator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flows (list): A list of login flows. Each flow is an ordered list of
|
flows: A list of login flows. Each flow is an ordered list of
|
||||||
strings representing auth-types. At least one full
|
strings representing auth-types. At least one full
|
||||||
flow must be completed in order for auth to be successful.
|
flow must be completed in order for auth to be successful.
|
||||||
|
|
||||||
clientdict: The dictionary from the client root level, not the
|
clientdict: The dictionary from the client root level, not the
|
||||||
'auth' key: this method prompts for auth if none is sent.
|
'auth' key: this method prompts for auth if none is sent.
|
||||||
|
|
||||||
clientip (str): The IP address of the client.
|
clientip: The IP address of the client.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[dict, dict, str]: a deferred tuple of
|
defer.Deferred[dict, dict, str]: a deferred tuple of
|
||||||
|
@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
authdict = None
|
authdict = None
|
||||||
sid = None
|
sid = None # type: Optional[str]
|
||||||
if clientdict and "auth" in clientdict:
|
if clientdict and "auth" in clientdict:
|
||||||
authdict = clientdict["auth"]
|
authdict = clientdict["auth"]
|
||||||
del clientdict["auth"]
|
del clientdict["auth"]
|
||||||
|
@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
|
||||||
creds = session["creds"]
|
creds = session["creds"]
|
||||||
|
|
||||||
# check auth type currently being presented
|
# check auth type currently being presented
|
||||||
errordict = {}
|
errordict = {} # type: Dict[str, Any]
|
||||||
if "type" in authdict:
|
if "type" in authdict:
|
||||||
login_type = authdict["type"]
|
login_type = authdict["type"] # type: str
|
||||||
try:
|
try:
|
||||||
result = yield self._check_auth_dict(authdict, clientip)
|
result = yield self._check_auth_dict(authdict, clientip)
|
||||||
if result:
|
if result:
|
||||||
|
@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
|
||||||
raise InteractiveAuthIncompleteError(ret)
|
raise InteractiveAuthIncompleteError(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
|
||||||
"""
|
"""
|
||||||
Adds the result of out-of-band authentication into an existing auth
|
Adds the result of out-of-band authentication into an existing auth
|
||||||
session. Currently used for adding the result of fallback auth.
|
session. Currently used for adding the result of fallback auth.
|
||||||
|
@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_session_id(self, clientdict):
|
def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Gets the session ID for a client given the client dictionary
|
Gets the session ID for a client given the client dictionary
|
||||||
|
|
||||||
|
@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
|
||||||
clientdict: The dictionary sent by the client in the request
|
clientdict: The dictionary sent by the client in the request
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str|None: The string session ID the client sent. If the client did
|
The string session ID the client sent. If the client did
|
||||||
not send a session ID, returns None.
|
not send a session ID, returns None.
|
||||||
"""
|
"""
|
||||||
sid = None
|
sid = None
|
||||||
|
@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
|
||||||
sid = authdict["session"]
|
sid = authdict["session"]
|
||||||
return sid
|
return sid
|
||||||
|
|
||||||
def set_session_data(self, session_id, key, value):
|
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Store a key-value pair into the sessions data associated with this
|
Store a key-value pair into the sessions data associated with this
|
||||||
request. This data is stored server-side and cannot be modified by
|
request. This data is stored server-side and cannot be modified by
|
||||||
the client.
|
the client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id (string): The ID of this session as returned from check_auth
|
session_id: The ID of this session as returned from check_auth
|
||||||
key (string): The key to store the data under
|
key: The key to store the data under
|
||||||
value (any): The data to store
|
value: The data to store
|
||||||
"""
|
"""
|
||||||
sess = self._get_session_info(session_id)
|
sess = self._get_session_info(session_id)
|
||||||
sess.setdefault("serverdict", {})[key] = value
|
sess.setdefault("serverdict", {})[key] = value
|
||||||
self._save_session(sess)
|
self._save_session(sess)
|
||||||
|
|
||||||
def get_session_data(self, session_id, key, default=None):
|
def get_session_data(
|
||||||
|
self, session_id: str, key: str, default: Optional[Any] = None
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Retrieve data stored with set_session_data
|
Retrieve data stored with set_session_data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id (string): The ID of this session as returned from check_auth
|
session_id: The ID of this session as returned from check_auth
|
||||||
key (string): The key to store the data under
|
key: The key to store the data under
|
||||||
default (any): Value to return if the key has not been set
|
default: Value to return if the key has not been set
|
||||||
"""
|
"""
|
||||||
sess = self._get_session_info(session_id)
|
sess = self._get_session_info(session_id)
|
||||||
return sess.setdefault("serverdict", {}).get(key, default)
|
return sess.setdefault("serverdict", {}).get(key, default)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_auth_dict(self, authdict, clientip):
|
def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
|
||||||
"""Attempt to validate the auth dict provided by a client
|
"""Attempt to validate the auth dict provided by a client
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
authdict (object): auth dict provided by the client
|
authdict: auth dict provided by the client
|
||||||
clientip (str): IP address of the client
|
clientip: IP address of the client
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: result of the stage verification.
|
Deferred: result of the stage verification.
|
||||||
|
@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
|
||||||
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
|
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
|
||||||
return canonical_id
|
return canonical_id
|
||||||
|
|
||||||
def _get_params_recaptcha(self):
|
def _get_params_recaptcha(self) -> dict:
|
||||||
return {"public_key": self.hs.config.recaptcha_public_key}
|
return {"public_key": self.hs.config.recaptcha_public_key}
|
||||||
|
|
||||||
def _get_params_terms(self):
|
def _get_params_terms(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"policies": {
|
"policies": {
|
||||||
"privacy_policy": {
|
"privacy_policy": {
|
||||||
|
@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _auth_dict_for_flows(self, flows, session):
|
def _auth_dict_for_flows(
|
||||||
|
self, flows: List[List[str]], session: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
public_flows = []
|
public_flows = []
|
||||||
for f in flows:
|
for f in flows:
|
||||||
public_flows.append(f)
|
public_flows.append(f)
|
||||||
|
@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
|
||||||
LoginType.TERMS: self._get_params_terms,
|
LoginType.TERMS: self._get_params_terms,
|
||||||
}
|
}
|
||||||
|
|
||||||
params = {}
|
params = {} # type: Dict[str, Any]
|
||||||
|
|
||||||
for f in public_flows:
|
for f in public_flows:
|
||||||
for stage in f:
|
for stage in f:
|
||||||
|
@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
|
||||||
"params": params,
|
"params": params,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_session_info(self, session_id):
|
def _get_session_info(self, session_id: Optional[str]) -> dict:
|
||||||
|
"""
|
||||||
|
Gets or creates a session given a session ID.
|
||||||
|
|
||||||
|
The session can be used to track data across multiple requests, e.g. for
|
||||||
|
interactive authentication.
|
||||||
|
"""
|
||||||
if session_id not in self.sessions:
|
if session_id not in self.sessions:
|
||||||
session_id = None
|
session_id = None
|
||||||
|
|
||||||
|
@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
|
||||||
return self.sessions[session_id]
|
return self.sessions[session_id]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
|
def get_access_token_for_user_id(
|
||||||
|
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Creates a new access token for the user with the given user ID.
|
Creates a new access token for the user with the given user ID.
|
||||||
|
|
||||||
|
@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
|
||||||
The device will be recorded in the table if it is not there already.
|
The device will be recorded in the table if it is not there already.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): canonical User ID
|
user_id: canonical User ID
|
||||||
device_id (str|None): the device ID to associate with the tokens.
|
device_id: the device ID to associate with the tokens.
|
||||||
None to leave the tokens unassociated with a device (deprecated:
|
None to leave the tokens unassociated with a device (deprecated:
|
||||||
we should always have a device ID)
|
we should always have a device ID)
|
||||||
valid_until_ms (int|None): when the token is valid until. None for
|
valid_until_ms: when the token is valid until. None for
|
||||||
no expiry.
|
no expiry.
|
||||||
Returns:
|
Returns:
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
|
@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_user_exists(self, user_id):
|
def check_user_exists(self, user_id: str):
|
||||||
"""
|
"""
|
||||||
Checks to see if a user with the given id exists. Will check case
|
Checks to see if a user with the given id exists. Will check case
|
||||||
insensitively, but return None if there are multiple inexact matches.
|
insensitively, but return None if there are multiple inexact matches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
(unicode|bytes) user_id: complete @user:id
|
user_id: complete @user:id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: (unicode) canonical_user_id, or None if zero or
|
defer.Deferred: (unicode) canonical_user_id, or None if zero or
|
||||||
|
@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _find_user_id_and_pwd_hash(self, user_id):
|
def _find_user_id_and_pwd_hash(self, user_id: str):
|
||||||
"""Checks to see if a user with the given id exists. Will check case
|
"""Checks to see if a user with the given id exists. Will check case
|
||||||
insensitively, but will return None if there are multiple inexact
|
insensitively, but will return None if there are multiple inexact
|
||||||
matches.
|
matches.
|
||||||
|
@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_supported_login_types(self):
|
def get_supported_login_types(self) -> Iterable[str]:
|
||||||
"""Get a the login types supported for the /login API
|
"""Get a the login types supported for the /login API
|
||||||
|
|
||||||
By default this is just 'm.login.password' (unless password_enabled is
|
By default this is just 'm.login.password' (unless password_enabled is
|
||||||
|
@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
|
||||||
other login types.
|
other login types.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Iterable[str]: login types
|
login types
|
||||||
"""
|
"""
|
||||||
return self._supported_login_types
|
return self._supported_login_types
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def validate_login(self, username, login_submission):
|
def validate_login(self, username: str, login_submission: Dict[str, Any]):
|
||||||
"""Authenticates the user for the /login API
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
Also used by the user-interactive auth flow to validate
|
Also used by the user-interactive auth flow to validate
|
||||||
m.login.password auth types.
|
m.login.password auth types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
username (str): username supplied by the user
|
username: username supplied by the user
|
||||||
login_submission (dict): the whole of the login submission
|
login_submission: the whole of the login submission
|
||||||
(including 'type' and other relevant fields)
|
(including 'type' and other relevant fields)
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[str, func]: canonical user id, and optional callback
|
Deferred[str, func]: canonical user id, and optional callback
|
||||||
|
@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
|
||||||
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_password_provider_3pid(self, medium, address, password):
|
def check_password_provider_3pid(self, medium: str, address: str, password: str):
|
||||||
"""Check if a password provider is able to validate a thirdparty login
|
"""Check if a password provider is able to validate a thirdparty login
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
medium (str): The medium of the 3pid (ex. email).
|
medium: The medium of the 3pid (ex. email).
|
||||||
address (str): The address of the 3pid (ex. jdoe@example.com).
|
address: The address of the 3pid (ex. jdoe@example.com).
|
||||||
password (str): The password of the user.
|
password: The password of the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[(str|None, func|None)]: A tuple of `(user_id,
|
Deferred[(str|None, func|None)]: A tuple of `(user_id,
|
||||||
|
@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_local_password(self, user_id, password):
|
def _check_local_password(self, user_id: str, password: str):
|
||||||
"""Authenticate a user against the local password database.
|
"""Authenticate a user against the local password database.
|
||||||
|
|
||||||
user_id is checked case insensitively, but will return None if there are
|
user_id is checked case insensitively, but will return None if there are
|
||||||
multiple inexact matches.
|
multiple inexact matches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (unicode): complete @user:id
|
user_id: complete @user:id
|
||||||
password (unicode): the provided password
|
password: the provided password
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[unicode] the canonical_user_id, or Deferred[None] if
|
Deferred[unicode] the canonical_user_id, or Deferred[None] if
|
||||||
unknown user/bad password
|
unknown user/bad password
|
||||||
|
@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
def validate_short_term_login_token_and_get_user_id(self, login_token: str):
|
||||||
auth_api = self.hs.get_auth()
|
auth_api = self.hs.get_auth()
|
||||||
user_id = None
|
user_id = None
|
||||||
try:
|
try:
|
||||||
|
@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_access_token(self, access_token):
|
def delete_access_token(self, access_token: str):
|
||||||
"""Invalidate a single access token
|
"""Invalidate a single access token
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
access_token (str): access token to be deleted
|
access_token: access token to be deleted
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred
|
||||||
|
@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_access_tokens_for_user(
|
def delete_access_tokens_for_user(
|
||||||
self, user_id, except_token_id=None, device_id=None
|
self,
|
||||||
|
user_id: str,
|
||||||
|
except_token_id: Optional[str] = None,
|
||||||
|
device_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Invalidate access tokens belonging to a user
|
"""Invalidate access tokens belonging to a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): ID of user the tokens belong to
|
user_id: ID of user the tokens belong to
|
||||||
except_token_id (str|None): access_token ID which should *not* be
|
except_token_id: access_token ID which should *not* be deleted
|
||||||
deleted
|
device_id: ID of device the tokens are associated with.
|
||||||
device_id (str|None): ID of device the tokens are associated with.
|
|
||||||
If None, tokens associated with any device (or no device) will
|
If None, tokens associated with any device (or no device) will
|
||||||
be deleted
|
be deleted
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_threepid(self, user_id, medium, address, validated_at):
|
def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
|
||||||
# check if medium has a valid value
|
# check if medium has a valid value
|
||||||
if medium not in ["email", "msisdn"]:
|
if medium not in ["email", "msisdn"]:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_threepid(self, user_id, medium, address, id_server=None):
|
def delete_threepid(
|
||||||
|
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
|
||||||
|
):
|
||||||
"""Attempts to unbind the 3pid on the identity servers and deletes it
|
"""Attempts to unbind the 3pid on the identity servers and deletes it
|
||||||
from the local database.
|
from the local database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id: ID of user to remove the 3pid from.
|
||||||
medium (str)
|
medium: The medium of the 3pid being removed: "email" or "msisdn".
|
||||||
address (str)
|
address: The 3pid address to remove.
|
||||||
id_server (str|None): Use the given identity server when unbinding
|
id_server: Use the given identity server when unbinding
|
||||||
any threepids. If None then will attempt to unbind using the
|
any threepids. If None then will attempt to unbind using the
|
||||||
identity server specified when binding (if known).
|
identity server specified when binding (if known).
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[bool]: Returns True if successfully unbound the 3pid on
|
Deferred[bool]: Returns True if successfully unbound the 3pid on
|
||||||
the identity server, False if identity server doesn't support the
|
the identity server, False if identity server doesn't support the
|
||||||
|
@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
|
||||||
yield self.store.user_delete_threepid(user_id, medium, address)
|
yield self.store.user_delete_threepid(user_id, medium, address)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _save_session(self, session):
|
def _save_session(self, session: Dict[str, Any]) -> None:
|
||||||
|
"""Update the last used time on the session to now and add it back to the session store."""
|
||||||
# TODO: Persistent storage
|
# TODO: Persistent storage
|
||||||
logger.debug("Saving session %s", session)
|
logger.debug("Saving session %s", session)
|
||||||
session["last_used"] = self.hs.get_clock().time_msec()
|
session["last_used"] = self.hs.get_clock().time_msec()
|
||||||
self.sessions[session["id"]] = session
|
self.sessions[session["id"]] = session
|
||||||
|
|
||||||
def hash(self, password):
|
def hash(self, password: str):
|
||||||
"""Computes a secure hash of password.
|
"""Computes a secure hash of password.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (unicode): Password to hash.
|
password: Password to hash.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(unicode): Hashed password.
|
Deferred(unicode): Hashed password.
|
||||||
|
@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
return defer_to_thread(self.hs.get_reactor(), _do_hash)
|
return defer_to_thread(self.hs.get_reactor(), _do_hash)
|
||||||
|
|
||||||
def validate_hash(self, password, stored_hash):
|
def validate_hash(self, password: str, stored_hash: bytes):
|
||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
password (unicode): Password to hash.
|
password: Password to hash.
|
||||||
stored_hash (bytes): Expected hash value.
|
stored_hash: Expected hash value.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||||
|
@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
|
||||||
|
|
||||||
hs = attr.ib()
|
hs = attr.ib()
|
||||||
|
|
||||||
def generate_access_token(self, user_id, extra_caveats=None):
|
def generate_access_token(
|
||||||
|
self, user_id: str, extra_caveats: Optional[List[str]] = None
|
||||||
|
) -> str:
|
||||||
extra_caveats = extra_caveats or []
|
extra_caveats = extra_caveats or []
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
|
||||||
macaroon.add_first_party_caveat(caveat)
|
macaroon.add_first_party_caveat(caveat)
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
def generate_short_term_login_token(
|
||||||
"""
|
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
|
||||||
|
) -> str:
|
||||||
Args:
|
|
||||||
user_id (unicode):
|
|
||||||
duration_in_ms (int):
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
unicode
|
|
||||||
"""
|
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = login")
|
macaroon.add_first_party_caveat("type = login")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
|
@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def generate_delete_pusher_token(self, user_id):
|
def generate_delete_pusher_token(self, user_id: str) -> str:
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = delete_pusher")
|
macaroon.add_first_party_caveat("type = delete_pusher")
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def _generate_base_macaroon(self, user_id):
|
def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self.hs.config.server_name,
|
location=self.hs.config.server_name,
|
||||||
identifier="key",
|
identifier="key",
|
||||||
|
|
1
tox.ini
1
tox.ini
|
@ -185,6 +185,7 @@ commands = mypy \
|
||||||
synapse/federation/federation_client.py \
|
synapse/federation/federation_client.py \
|
||||||
synapse/federation/sender \
|
synapse/federation/sender \
|
||||||
synapse/federation/transport \
|
synapse/federation/transport \
|
||||||
|
synapse/handlers/auth.py \
|
||||||
synapse/handlers/directory.py \
|
synapse/handlers/directory.py \
|
||||||
synapse/handlers/presence.py \
|
synapse/handlers/presence.py \
|
||||||
synapse/handlers/sync.py \
|
synapse/handlers/sync.py \
|
||||||
|
|
Loading…
Reference in a new issue