UI Auth via SSO: redirect the user to an appropriate SSO. (#9081)

If we have integrations with multiple identity providers, when the user does a UI Auth, we need to redirect them to the right one.

There are a few steps to this. First of all we actually need to store the userid of the user we are trying to validate in the UIA session, since the /auth/sso/fallback/web request is unauthenticated.

Then, once we get the /auth/sso/fallback/web request, we can fish the user id out of the session, and use it to look up the external id mappings, and hence pick an SSO provider for them.
This commit is contained in:
Richard van der Hoff 2021-01-12 17:38:03 +00:00 committed by GitHub
parent 723b19748a
commit 789d9ebad3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 133 additions and 60 deletions

1
changelog.d/9081.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

View file

@ -50,7 +50,10 @@ from synapse.api.errors import (
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth import (
INTERACTIVE_AUTH_CHECKERS,
UIAuthSessionDataConstants,
)
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
from synapse.http.server import finish_request, respond_with_html from synapse.http.server import finish_request, respond_with_html
@ -335,10 +338,10 @@ class AuthHandler(BaseHandler):
request_body.pop("auth", None) request_body.pop("auth", None)
return request_body, None return request_body, None
user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts # Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
# build a list of supported flows # build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types( supported_ui_auth_types = await self._get_available_ui_auth_types(
@ -346,13 +349,16 @@ class AuthHandler(BaseHandler):
) )
flows = [[login_type] for login_type in supported_ui_auth_types] flows = [[login_type] for login_type in supported_ui_auth_types]
def get_new_session_data() -> JsonDict:
return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
try: try:
result, params, session_id = await self.check_ui_auth( result, params, session_id = await self.check_ui_auth(
flows, request, request_body, description flows, request, request_body, description, get_new_session_data,
) )
except LoginError: except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise). # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(user_id) self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
raise raise
# find the completed login type # find the completed login type
@ -360,14 +366,14 @@ class AuthHandler(BaseHandler):
if login_type not in result: if login_type not in result:
continue continue
user_id = result[login_type] validated_user_id = result[login_type]
break break
else: else:
# this can't happen # this can't happen
raise Exception("check_auth returned True but no successful login type") raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token # check that the UI auth matched the access token
if user_id != requester.user.to_string(): if validated_user_id != requester_user_id:
raise AuthError(403, "Invalid auth") raise AuthError(403, "Invalid auth")
# Note that the access token has been validated. # Note that the access token has been validated.
@ -399,13 +405,9 @@ class AuthHandler(BaseHandler):
# if sso is enabled, allow the user to log in via SSO iff they have a mapping # if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid. # from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled: if await self.hs.get_sso_handler().get_identity_providers_for_user(
if await self.store.get_external_ids_by_user(user.to_string()): user.to_string()
ui_auth_types.add(LoginType.SSO) ):
# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
ui_auth_types.add(LoginType.SSO) ui_auth_types.add(LoginType.SSO)
return ui_auth_types return ui_auth_types
@ -424,6 +426,7 @@ class AuthHandler(BaseHandler):
request: SynapseRequest, request: SynapseRequest,
clientdict: Dict[str, Any], clientdict: Dict[str, Any],
description: str, description: str,
get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]: ) -> Tuple[dict, dict, str]:
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
@ -447,6 +450,13 @@ class AuthHandler(BaseHandler):
description: A human readable string to be displayed to the user that description: A human readable string to be displayed to the user that
describes the operation happening on their account. describes the operation happening on their account.
get_new_session_data:
an optional callback which will be called when starting a new session.
it should return data to be stored as part of the session.
The keys of the returned data should be entries in
UIAuthSessionDataConstants.
Returns: Returns:
A tuple of (creds, params, session_id). A tuple of (creds, params, session_id).
@ -474,10 +484,15 @@ class AuthHandler(BaseHandler):
# If there's no session ID, create a new session. # If there's no session ID, create a new session.
if not sid: if not sid:
new_session_data = get_new_session_data() if get_new_session_data else {}
session = await self.store.create_ui_auth_session( session = await self.store.create_ui_auth_session(
clientdict, uri, method, description clientdict, uri, method, description
) )
for k, v in new_session_data.items():
await self.set_session_data(session.session_id, k, v)
else: else:
try: try:
session = await self.store.get_ui_auth_session(sid) session = await self.store.get_ui_auth_session(sid)
@ -639,7 +654,8 @@ class AuthHandler(BaseHandler):
Args: Args:
session_id: The ID of this session as returned from check_auth session_id: The ID of this session as returned from check_auth
key: The key to store the data under key: The key to store the data under. An entry from
UIAuthSessionDataConstants.
value: The data to store value: The data to store
""" """
try: try:
@ -655,7 +671,8 @@ class AuthHandler(BaseHandler):
Args: Args:
session_id: The ID of this session as returned from check_auth session_id: The ID of this session as returned from check_auth
key: The key to store the data under key: The key the data was stored under. An entry from
UIAuthSessionDataConstants.
default: Value to return if the key has not been set default: Value to return if the key has not been set
""" """
try: try:
@ -1329,12 +1346,12 @@ class AuthHandler(BaseHandler):
else: else:
return False return False
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
""" """
Get the HTML for the SSO redirect confirmation page. Get the HTML for the SSO redirect confirmation page.
Args: Args:
redirect_url: The URL to redirect to the SSO provider. request: The incoming HTTP request
session_id: The user interactive authentication session ID. session_id: The user interactive authentication session ID.
Returns: Returns:
@ -1344,6 +1361,35 @@ class AuthHandler(BaseHandler):
session = await self.store.get_ui_auth_session(session_id) session = await self.store.get_ui_auth_session(session_id)
except StoreError: except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
user_id_to_verify = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
)
if not idps:
# we checked that the user had some remote identities before offering an SSO
# flow, so either it's been deleted or the client has requested SSO despite
# it not being offered.
raise SynapseError(400, "User has no SSO identities")
# for now, just pick one
idp_id, sso_auth_provider = next(iter(idps.items()))
if len(idps) > 0:
logger.warning(
"User %r has previously logged in with multiple SSO IdPs; arbitrarily "
"picking %r",
user_id_to_verify,
idp_id,
)
redirect_url = await sso_auth_provider.handle_redirect_request(
request, None, session_id
)
return self._sso_auth_confirm_template.render( return self._sso_auth_confirm_template.render(
description=session.description, redirect_url=redirect_url, description=session.description, redirect_url=redirect_url,
) )

View file

@ -167,6 +167,37 @@ class SsoHandler:
"""Get the configured identity providers""" """Get the configured identity providers"""
return self._identity_providers return self._identity_providers
async def get_identity_providers_for_user(
self, user_id: str
) -> Mapping[str, SsoIdentityProvider]:
"""Get the SsoIdentityProviders which a user has used
Given a user id, get the identity providers that that user has used to log in
with in the past (and thus could use to re-identify themselves for UI Auth).
Args:
user_id: MXID of user to look up
Raises:
a map of idp_id to SsoIdentityProvider
"""
external_ids = await self._store.get_external_ids_by_user(user_id)
valid_idps = {}
for idp_id, _ in external_ids:
idp = self._identity_providers.get(idp_id)
if not idp:
logger.warning(
"User %r has an SSO mapping for IdP %r, but this is no longer "
"configured.",
user_id,
idp_id,
)
else:
valid_idps[idp_id] = idp
return valid_idps
def render_error( def render_error(
self, self,
request: Request, request: Request,

View file

@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
""" """
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401 from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
class UIAuthSessionDataConstants:
"""Constants for use with AuthHandler.set_session_data"""
# used during registration and password reset to store a hashed copy of the
# password, so that the client does not need to submit it each time.
PASSWORD_HASH = "password_hash"
# used during registration to store the mxid of the registered user
REGISTERED_USER_ID = "registered_user_id"
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"

View file

@ -20,9 +20,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from urllib.parse import urlparse from urllib.parse import urlparse
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
Codes, Codes,
@ -31,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError, ThreepidValidationError,
) )
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -200,7 +202,9 @@ class PasswordRestServlet(RestServlet):
if new_password: if new_password:
password_hash = await self.auth_handler.hash(new_password) password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data( await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
) )
raise raise
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -222,7 +226,9 @@ class PasswordRestServlet(RestServlet):
if new_password: if new_password:
password_hash = await self.auth_handler.hash(new_password) password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data( await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
) )
raise raise
@ -255,7 +261,7 @@ class PasswordRestServlet(RestServlet):
password_hash = await self.auth_handler.hash(new_password) password_hash = await self.auth_handler.hash(new_password)
elif session_id is not None: elif session_id is not None:
password_hash = await self.auth_handler.get_session_data( password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
) )
else: else:
# UI validation was skipped, but the request did not include a new # UI validation was skipped, but the request did not include a new

View file

@ -19,7 +19,6 @@ from typing import TYPE_CHECKING
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX from synapse.api.urls import CLIENT_API_PREFIX
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet, parse_string from synapse.http.servlet import RestServlet, parse_string
@ -46,22 +45,6 @@ class AuthRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
# SSO configuration.
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._saml_enabled = hs.config.saml2_enabled
if self._saml_enabled:
self._saml_handler = hs.get_saml_handler()
self._oidc_enabled = hs.config.oidc_enabled
if self._oidc_enabled:
self._oidc_handler = hs.get_oidc_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self.recaptcha_template = hs.config.recaptcha_template self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template self.terms_template = hs.config.terms_template
self.success_template = hs.config.fallback_success_template self.success_template = hs.config.fallback_success_template
@ -90,21 +73,7 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO: elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to # Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider. # re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
if self._cas_enabled:
sso_auth_provider = self._cas_handler # type: SsoIdentityProvider
elif self._saml_enabled:
sso_auth_provider = self._saml_handler
elif self._oidc_enabled:
sso_auth_provider = self._oidc_handler
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
sso_redirect_url = await sso_auth_provider.handle_redirect_request(
request, None, session
)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
else: else:
raise SynapseError(404, "Unknown auth stage type") raise SynapseError(404, "Unknown auth stage type")

View file

@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
# user here. We carry on and go through the auth checks though, # user here. We carry on and go through the auth checks though,
# for paranoia. # for paranoia.
registered_user_id = await self.auth_handler.get_session_data( registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
) )
# Extract the previously-hashed password from the session. # Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data( password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
) )
# Ensure that the username is valid. # Ensure that the username is valid.
@ -528,7 +529,9 @@ class RegisterRestServlet(RestServlet):
if not password_hash and password: if not password_hash and password:
password_hash = await self.auth_handler.hash(password) password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data( await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
) )
raise raise
@ -629,7 +632,9 @@ class RegisterRestServlet(RestServlet):
# Remember that the user account has been registered (and the user # Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified). # ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data( await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id session_id,
UIAuthSessionDataConstants.REGISTERED_USER_ID,
registered_user_id,
) )
registered = True registered = True