mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-17 15:31:19 +01:00
support custom login types for validating users
Wire the custom login type support from password providers into the UI-auth user-validation flows.
This commit is contained in:
parent
cc58e177f3
commit
da1010c83a
1 changed files with 57 additions and 24 deletions
|
@ -49,7 +49,6 @@ class AuthHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
super(AuthHandler, self).__init__(hs)
|
super(AuthHandler, self).__init__(hs)
|
||||||
self.checkers = {
|
self.checkers = {
|
||||||
LoginType.PASSWORD: self._check_password_auth,
|
|
||||||
LoginType.RECAPTCHA: self._check_recaptcha,
|
LoginType.RECAPTCHA: self._check_recaptcha,
|
||||||
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
||||||
LoginType.MSISDN: self._check_msisdn,
|
LoginType.MSISDN: self._check_msisdn,
|
||||||
|
@ -78,15 +77,20 @@ class AuthHandler(BaseHandler):
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self._password_enabled = hs.config.password_enabled
|
self._password_enabled = hs.config.password_enabled
|
||||||
|
|
||||||
login_types = set()
|
# we keep this as a list despite the O(N^2) implication so that we can
|
||||||
|
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||||
|
# type in the list. (NB that the spec doesn't require us to do so and
|
||||||
|
# clients which favour types that they don't understand over those that
|
||||||
|
# they do are technically broken)
|
||||||
|
login_types = []
|
||||||
if self._password_enabled:
|
if self._password_enabled:
|
||||||
login_types.add(LoginType.PASSWORD)
|
login_types.append(LoginType.PASSWORD)
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "get_supported_login_types"):
|
if hasattr(provider, "get_supported_login_types"):
|
||||||
login_types.update(
|
for t in provider.get_supported_login_types().keys():
|
||||||
provider.get_supported_login_types().keys()
|
if t not in login_types:
|
||||||
)
|
login_types.append(t)
|
||||||
self._supported_login_types = frozenset(login_types)
|
self._supported_login_types = login_types
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def validate_user_via_ui_auth(self, requester, request_body, clientip):
|
def validate_user_via_ui_auth(self, requester, request_body, clientip):
|
||||||
|
@ -116,14 +120,27 @@ class AuthHandler(BaseHandler):
|
||||||
a different user to `requester`
|
a different user to `requester`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# we only support password login here
|
# build a list of supported flows
|
||||||
flows = [[LoginType.PASSWORD]]
|
flows = [
|
||||||
|
[login_type] for login_type in self._supported_login_types
|
||||||
|
]
|
||||||
|
|
||||||
result, params, _ = yield self.check_auth(
|
result, params, _ = yield self.check_auth(
|
||||||
flows, request_body, clientip,
|
flows, request_body, clientip,
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = result[LoginType.PASSWORD]
|
# find the completed login type
|
||||||
|
for login_type in self._supported_login_types:
|
||||||
|
if login_type not in result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
user_id = result[login_type]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# this can't happen
|
||||||
|
raise Exception(
|
||||||
|
"check_auth returned True but no successful login type",
|
||||||
|
)
|
||||||
|
|
||||||
# check that the UI auth matched the access token
|
# check that the UI auth matched the access token
|
||||||
if user_id != requester.user.to_string():
|
if user_id != requester.user.to_string():
|
||||||
|
@ -210,14 +227,12 @@ class AuthHandler(BaseHandler):
|
||||||
errordict = {}
|
errordict = {}
|
||||||
if 'type' in authdict:
|
if 'type' in authdict:
|
||||||
login_type = authdict['type']
|
login_type = authdict['type']
|
||||||
if login_type not in self.checkers:
|
|
||||||
raise LoginError(400, "", Codes.UNRECOGNIZED)
|
|
||||||
try:
|
try:
|
||||||
result = yield self.checkers[login_type](authdict, clientip)
|
result = yield self._check_auth_dict(authdict, clientip)
|
||||||
if result:
|
if result:
|
||||||
creds[login_type] = result
|
creds[login_type] = result
|
||||||
self._save_session(session)
|
self._save_session(session)
|
||||||
except LoginError, e:
|
except LoginError as e:
|
||||||
if login_type == LoginType.EMAIL_IDENTITY:
|
if login_type == LoginType.EMAIL_IDENTITY:
|
||||||
# riot used to have a bug where it would request a new
|
# riot used to have a bug where it would request a new
|
||||||
# validation token (thus sending a new email) each time it
|
# validation token (thus sending a new email) each time it
|
||||||
|
@ -226,7 +241,7 @@ class AuthHandler(BaseHandler):
|
||||||
#
|
#
|
||||||
# Grandfather in the old behaviour for now to avoid
|
# Grandfather in the old behaviour for now to avoid
|
||||||
# breaking old riot deployments.
|
# breaking old riot deployments.
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
# this step failed. Merge the error dict into the response
|
# this step failed. Merge the error dict into the response
|
||||||
# so that the client can have another go.
|
# so that the client can have another go.
|
||||||
|
@ -323,17 +338,35 @@ class AuthHandler(BaseHandler):
|
||||||
return sess.setdefault('serverdict', {}).get(key, default)
|
return sess.setdefault('serverdict', {}).get(key, default)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password_auth(self, authdict, _):
|
def _check_auth_dict(self, authdict, clientip):
|
||||||
if "user" not in authdict or "password" not in authdict:
|
"""Attempt to validate the auth dict provided by a client
|
||||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
|
||||||
|
|
||||||
user_id = authdict["user"]
|
Args:
|
||||||
password = authdict["password"]
|
authdict (object): auth dict provided by the client
|
||||||
|
clientip (str): IP address of the client
|
||||||
|
|
||||||
(canonical_id, callback) = yield self.validate_login(user_id, {
|
Returns:
|
||||||
"type": LoginType.PASSWORD,
|
Deferred: result of the stage verification.
|
||||||
"password": password,
|
|
||||||
})
|
Raises:
|
||||||
|
StoreError if there was a problem accessing the database
|
||||||
|
SynapseError if there was a problem with the request
|
||||||
|
LoginError if there was an authentication problem.
|
||||||
|
"""
|
||||||
|
login_type = authdict['type']
|
||||||
|
checker = self.checkers.get(login_type)
|
||||||
|
if checker is not None:
|
||||||
|
res = yield checker(authdict, clientip)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
# build a v1-login-style dict out of the authdict and fall back to the
|
||||||
|
# v1 code
|
||||||
|
user_id = authdict.get("user")
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
|
||||||
defer.returnValue(canonical_id)
|
defer.returnValue(canonical_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
Loading…
Reference in a new issue