mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-21 02:41:56 +01:00
Merge pull request #2609 from matrix-org/rav/refactor_login
Refactor some logic from LoginRestServlet into AuthHandler
This commit is contained in:
commit
c31a7c3ff6
2 changed files with 79 additions and 58 deletions
|
@ -77,6 +77,12 @@ class AuthHandler(BaseHandler):
|
|||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._password_enabled = hs.config.password_enabled
|
||||
|
||||
login_types = set()
|
||||
if self._password_enabled:
|
||||
login_types.add(LoginType.PASSWORD)
|
||||
self._supported_login_types = frozenset(login_types)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
|
@ -266,10 +272,11 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
user_id = authdict["user"]
|
||||
password = authdict["password"]
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID(user_id, self.hs.hostname).to_string()
|
||||
|
||||
return self._check_password(user_id, password)
|
||||
return self.validate_login(user_id, {
|
||||
"type": LoginType.PASSWORD,
|
||||
"password": password,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_recaptcha(self, authdict, clientip):
|
||||
|
@ -398,23 +405,6 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def validate_password_login(self, user_id, password):
|
||||
"""
|
||||
Authenticates the user with their username and password.
|
||||
|
||||
Used only by the v1 login API.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
password (str): Password
|
||||
Returns:
|
||||
defer.Deferred: (str) canonical user id
|
||||
Raises:
|
||||
StoreError if there was a problem accessing the database
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_access_token_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
|
@ -501,26 +491,60 @@ class AuthHandler(BaseHandler):
|
|||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password(self, user_id, password):
|
||||
"""Authenticate a user against the LDAP and local databases.
|
||||
def get_supported_login_types(self):
|
||||
"""Get a the login types supported for the /login API
|
||||
|
||||
user_id is checked case insensitively against the local database, but
|
||||
will throw if there are multiple inexact matches.
|
||||
By default this is just 'm.login.password' (unless password_enabled is
|
||||
False in the config file), but password auth providers can provide
|
||||
other login types.
|
||||
|
||||
Returns:
|
||||
Iterable[str]: login types
|
||||
"""
|
||||
return self._supported_login_types
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_login(self, user_id, login_submission):
|
||||
"""Authenticates the user for the /login API
|
||||
|
||||
Also used by the user-interactive auth flow to validate
|
||||
m.login.password auth types.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
user_id (str): user_id supplied by the user
|
||||
login_submission (dict): the whole of the login submission
|
||||
(including 'type' and other relevant fields)
|
||||
Returns:
|
||||
(str) the canonical_user_id
|
||||
Deferred[str]: canonical user id
|
||||
Raises:
|
||||
LoginError if login fails
|
||||
StoreError if there was a problem accessing the database
|
||||
SynapseError if there was a problem with the request
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID(
|
||||
user_id, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
login_type = login_submission.get("type")
|
||||
|
||||
if login_type != LoginType.PASSWORD:
|
||||
raise SynapseError(400, "Bad login type.")
|
||||
if not self._password_enabled:
|
||||
raise SynapseError(400, "Password login has been disabled.")
|
||||
if "password" not in login_submission:
|
||||
raise SynapseError(400, "Missing parameter: password")
|
||||
|
||||
password = login_submission["password"]
|
||||
for provider in self.password_providers:
|
||||
is_valid = yield provider.check_password(user_id, password)
|
||||
if is_valid:
|
||||
defer.returnValue(user_id)
|
||||
|
||||
canonical_user_id = yield self._check_local_password(user_id, password)
|
||||
canonical_user_id = yield self._check_local_password(
|
||||
user_id, password,
|
||||
)
|
||||
|
||||
if canonical_user_id:
|
||||
defer.returnValue(canonical_user_id)
|
||||
|
|
|
@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
|
|||
|
||||
class LoginRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login$")
|
||||
PASS_TYPE = "m.login.password"
|
||||
SAML2_TYPE = "m.login.saml2"
|
||||
CAS_TYPE = "m.login.cas"
|
||||
TOKEN_TYPE = "m.login.token"
|
||||
|
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(LoginRestServlet, self).__init__(hs)
|
||||
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||
self.password_enabled = hs.config.password_enabled
|
||||
self.saml2_enabled = hs.config.saml2_enabled
|
||||
self.jwt_enabled = hs.config.jwt_enabled
|
||||
self.jwt_secret = hs.config.jwt_secret
|
||||
|
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
# fall back to the fallback API if they don't understand one of the
|
||||
# login flow types returned.
|
||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||
if self.password_enabled:
|
||||
flows.append({"type": LoginRestServlet.PASS_TYPE})
|
||||
|
||||
flows.extend((
|
||||
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||
))
|
||||
|
||||
return (200, {"flows": flows})
|
||||
|
||||
|
@ -133,13 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
def on_POST(self, request):
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
||||
if not self.password_enabled:
|
||||
raise SynapseError(400, "Password login has been disabled.")
|
||||
|
||||
result = yield self.do_password_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
elif self.saml2_enabled and (login_submission["type"] ==
|
||||
if self.saml2_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.SAML2_TYPE):
|
||||
relay_state = ""
|
||||
if "relay_state" in login_submission:
|
||||
|
@ -157,15 +151,21 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
result = yield self.do_token_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
else:
|
||||
raise SynapseError(400, "Bad login type.")
|
||||
result = yield self._do_other_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_password_login(self, login_submission):
|
||||
if "password" not in login_submission:
|
||||
raise SynapseError(400, "Missing parameter: password")
|
||||
def _do_other_login(self, login_submission):
|
||||
"""Handle non-token/saml/jwt logins
|
||||
|
||||
Args:
|
||||
login_submission:
|
||||
|
||||
Returns:
|
||||
(int, object): HTTP code/response
|
||||
"""
|
||||
login_submission_legacy_convert(login_submission)
|
||||
|
||||
if "identifier" not in login_submission:
|
||||
|
@ -208,25 +208,22 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
if "user" not in identifier:
|
||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
user_id = identifier["user"]
|
||||
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID(
|
||||
user_id, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
auth_handler = self.auth_handler
|
||||
user_id = yield auth_handler.validate_password_login(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"],
|
||||
canonical_user_id = yield auth_handler.validate_login(
|
||||
identifier["user"],
|
||||
login_submission,
|
||||
)
|
||||
|
||||
device_id = yield self._register_device(
|
||||
canonical_user_id, login_submission,
|
||||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||
user_id, device_id,
|
||||
canonical_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name"),
|
||||
)
|
||||
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"user_id": canonical_user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
|
|
Loading…
Add table
Reference in a new issue