mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-21 02:11:55 +01:00
Merge pull request #2609 from matrix-org/rav/refactor_login
Refactor some logic from LoginRestServlet into AuthHandler
This commit is contained in:
commit
c31a7c3ff6
2 changed files with 79 additions and 58 deletions
|
@ -77,6 +77,12 @@ class AuthHandler(BaseHandler):
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
self._password_enabled = hs.config.password_enabled
|
||||||
|
|
||||||
|
login_types = set()
|
||||||
|
if self._password_enabled:
|
||||||
|
login_types.add(LoginType.PASSWORD)
|
||||||
|
self._supported_login_types = frozenset(login_types)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -266,10 +272,11 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
user_id = authdict["user"]
|
user_id = authdict["user"]
|
||||||
password = authdict["password"]
|
password = authdict["password"]
|
||||||
if not user_id.startswith('@'):
|
|
||||||
user_id = UserID(user_id, self.hs.hostname).to_string()
|
|
||||||
|
|
||||||
return self._check_password(user_id, password)
|
return self.validate_login(user_id, {
|
||||||
|
"type": LoginType.PASSWORD,
|
||||||
|
"password": password,
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_recaptcha(self, authdict, clientip):
|
def _check_recaptcha(self, authdict, clientip):
|
||||||
|
@ -398,23 +405,6 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
return self.sessions[session_id]
|
return self.sessions[session_id]
|
||||||
|
|
||||||
def validate_password_login(self, user_id, password):
|
|
||||||
"""
|
|
||||||
Authenticates the user with their username and password.
|
|
||||||
|
|
||||||
Used only by the v1 login API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): complete @user:id
|
|
||||||
password (str): Password
|
|
||||||
Returns:
|
|
||||||
defer.Deferred: (str) canonical user id
|
|
||||||
Raises:
|
|
||||||
StoreError if there was a problem accessing the database
|
|
||||||
LoginError if there was an authentication problem.
|
|
||||||
"""
|
|
||||||
return self._check_password(user_id, password)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_access_token_for_user_id(self, user_id, device_id=None,
|
def get_access_token_for_user_id(self, user_id, device_id=None,
|
||||||
initial_display_name=None):
|
initial_display_name=None):
|
||||||
|
@ -501,26 +491,60 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def get_supported_login_types(self):
|
||||||
def _check_password(self, user_id, password):
|
"""Get a the login types supported for the /login API
|
||||||
"""Authenticate a user against the LDAP and local databases.
|
|
||||||
|
|
||||||
user_id is checked case insensitively against the local database, but
|
By default this is just 'm.login.password' (unless password_enabled is
|
||||||
will throw if there are multiple inexact matches.
|
False in the config file), but password auth providers can provide
|
||||||
|
other login types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[str]: login types
|
||||||
|
"""
|
||||||
|
return self._supported_login_types
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def validate_login(self, user_id, login_submission):
|
||||||
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
|
Also used by the user-interactive auth flow to validate
|
||||||
|
m.login.password auth types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): complete @user:id
|
user_id (str): user_id supplied by the user
|
||||||
|
login_submission (dict): the whole of the login submission
|
||||||
|
(including 'type' and other relevant fields)
|
||||||
Returns:
|
Returns:
|
||||||
(str) the canonical_user_id
|
Deferred[str]: canonical user id
|
||||||
Raises:
|
Raises:
|
||||||
LoginError if login fails
|
StoreError if there was a problem accessing the database
|
||||||
|
SynapseError if there was a problem with the request
|
||||||
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not user_id.startswith('@'):
|
||||||
|
user_id = UserID(
|
||||||
|
user_id, self.hs.hostname
|
||||||
|
).to_string()
|
||||||
|
|
||||||
|
login_type = login_submission.get("type")
|
||||||
|
|
||||||
|
if login_type != LoginType.PASSWORD:
|
||||||
|
raise SynapseError(400, "Bad login type.")
|
||||||
|
if not self._password_enabled:
|
||||||
|
raise SynapseError(400, "Password login has been disabled.")
|
||||||
|
if "password" not in login_submission:
|
||||||
|
raise SynapseError(400, "Missing parameter: password")
|
||||||
|
|
||||||
|
password = login_submission["password"]
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
is_valid = yield provider.check_password(user_id, password)
|
is_valid = yield provider.check_password(user_id, password)
|
||||||
if is_valid:
|
if is_valid:
|
||||||
defer.returnValue(user_id)
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
canonical_user_id = yield self._check_local_password(user_id, password)
|
canonical_user_id = yield self._check_local_password(
|
||||||
|
user_id, password,
|
||||||
|
)
|
||||||
|
|
||||||
if canonical_user_id:
|
if canonical_user_id:
|
||||||
defer.returnValue(canonical_user_id)
|
defer.returnValue(canonical_user_id)
|
||||||
|
|
|
@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
|
||||||
|
|
||||||
class LoginRestServlet(ClientV1RestServlet):
|
class LoginRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login$")
|
PATTERNS = client_path_patterns("/login$")
|
||||||
PASS_TYPE = "m.login.password"
|
|
||||||
SAML2_TYPE = "m.login.saml2"
|
SAML2_TYPE = "m.login.saml2"
|
||||||
CAS_TYPE = "m.login.cas"
|
CAS_TYPE = "m.login.cas"
|
||||||
TOKEN_TYPE = "m.login.token"
|
TOKEN_TYPE = "m.login.token"
|
||||||
|
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(LoginRestServlet, self).__init__(hs)
|
super(LoginRestServlet, self).__init__(hs)
|
||||||
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||||
self.password_enabled = hs.config.password_enabled
|
|
||||||
self.saml2_enabled = hs.config.saml2_enabled
|
self.saml2_enabled = hs.config.saml2_enabled
|
||||||
self.jwt_enabled = hs.config.jwt_enabled
|
self.jwt_enabled = hs.config.jwt_enabled
|
||||||
self.jwt_secret = hs.config.jwt_secret
|
self.jwt_secret = hs.config.jwt_secret
|
||||||
|
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
# fall back to the fallback API if they don't understand one of the
|
# fall back to the fallback API if they don't understand one of the
|
||||||
# login flow types returned.
|
# login flow types returned.
|
||||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||||
if self.password_enabled:
|
|
||||||
flows.append({"type": LoginRestServlet.PASS_TYPE})
|
flows.extend((
|
||||||
|
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||||
|
))
|
||||||
|
|
||||||
return (200, {"flows": flows})
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
|
@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
login_submission = parse_json_object_from_request(request)
|
login_submission = parse_json_object_from_request(request)
|
||||||
try:
|
try:
|
||||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
if self.saml2_enabled and (login_submission["type"] ==
|
||||||
if not self.password_enabled:
|
LoginRestServlet.SAML2_TYPE):
|
||||||
raise SynapseError(400, "Password login has been disabled.")
|
|
||||||
|
|
||||||
result = yield self.do_password_login(login_submission)
|
|
||||||
defer.returnValue(result)
|
|
||||||
elif self.saml2_enabled and (login_submission["type"] ==
|
|
||||||
LoginRestServlet.SAML2_TYPE):
|
|
||||||
relay_state = ""
|
relay_state = ""
|
||||||
if "relay_state" in login_submission:
|
if "relay_state" in login_submission:
|
||||||
relay_state = "&RelayState=" + urllib.quote(
|
relay_state = "&RelayState=" + urllib.quote(
|
||||||
|
@ -157,15 +151,21 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
result = yield self.do_token_login(login_submission)
|
result = yield self.do_token_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "Bad login type.")
|
result = yield self._do_other_login(login_submission)
|
||||||
|
defer.returnValue(result)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_password_login(self, login_submission):
|
def _do_other_login(self, login_submission):
|
||||||
if "password" not in login_submission:
|
"""Handle non-token/saml/jwt logins
|
||||||
raise SynapseError(400, "Missing parameter: password")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
login_submission:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(int, object): HTTP code/response
|
||||||
|
"""
|
||||||
login_submission_legacy_convert(login_submission)
|
login_submission_legacy_convert(login_submission)
|
||||||
|
|
||||||
if "identifier" not in login_submission:
|
if "identifier" not in login_submission:
|
||||||
|
@ -208,25 +208,22 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
if "user" not in identifier:
|
if "user" not in identifier:
|
||||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||||
|
|
||||||
user_id = identifier["user"]
|
|
||||||
|
|
||||||
if not user_id.startswith('@'):
|
|
||||||
user_id = UserID(
|
|
||||||
user_id, self.hs.hostname
|
|
||||||
).to_string()
|
|
||||||
|
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = yield auth_handler.validate_password_login(
|
canonical_user_id = yield auth_handler.validate_login(
|
||||||
user_id=user_id,
|
identifier["user"],
|
||||||
password=login_submission["password"],
|
login_submission,
|
||||||
|
)
|
||||||
|
|
||||||
|
device_id = yield self._register_device(
|
||||||
|
canonical_user_id, login_submission,
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
access_token = yield auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id,
|
canonical_user_id, device_id,
|
||||||
login_submission.get("initial_device_display_name"),
|
login_submission.get("initial_device_display_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": canonical_user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
|
|
Loading…
Add table
Reference in a new issue