forked from MirrorHub/synapse
Store an IdP ID in the OIDC session (#9109)
Again in preparation for handling more than one OIDC provider, add a new caveat to the macaroon used as an OIDC session cookie, which remembers which OIDC provider we are talking to. In future, when we get a callback, we'll need it to make sure we talk to the right IdP. As part of this, I'm adding an idp_id and idp_name field to the OIDC configuration object. They aren't yet documented, and we'll just use the old values by default.
This commit is contained in:
parent
20af310889
commit
4575ad0b1e
4 changed files with 42 additions and 10 deletions
1
changelog.d/9109.feature
Normal file
1
changelog.d/9109.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for multiple SSO Identity Providers.
|
|
@ -1,6 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2020 Quentin Gliech
|
# Copyright 2020 Quentin Gliech
|
||||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -14,6 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import string
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -38,7 +39,7 @@ class OIDCConfig(Config):
|
||||||
|
|
||||||
oidc_config = config.get("oidc_config")
|
oidc_config = config.get("oidc_config")
|
||||||
if oidc_config and oidc_config.get("enabled", False):
|
if oidc_config and oidc_config.get("enabled", False):
|
||||||
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
|
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
|
||||||
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
|
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
|
||||||
|
|
||||||
if not self.oidc_provider:
|
if not self.oidc_provider:
|
||||||
|
@ -205,6 +206,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["issuer", "client_id", "client_secret"],
|
"required": ["issuer", "client_id", "client_secret"],
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
|
||||||
|
"idp_name": {"type": "string"},
|
||||||
"discover": {"type": "boolean"},
|
"discover": {"type": "boolean"},
|
||||||
"issuer": {"type": "string"},
|
"issuer": {"type": "string"},
|
||||||
"client_id": {"type": "string"},
|
"client_id": {"type": "string"},
|
||||||
|
@ -277,7 +280,17 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
|
||||||
"methods: %s" % (", ".join(missing_methods),)
|
"methods: %s" % (", ".join(missing_methods),)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# MSC2858 will appy certain limits in what can be used as an IdP id, so let's
|
||||||
|
# enforce those limits now.
|
||||||
|
idp_id = oidc_config.get("idp_id", "oidc")
|
||||||
|
valid_idp_chars = set(string.ascii_letters + string.digits + "-._~")
|
||||||
|
|
||||||
|
if any(c not in valid_idp_chars for c in idp_id):
|
||||||
|
raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"')
|
||||||
|
|
||||||
return OidcProviderConfig(
|
return OidcProviderConfig(
|
||||||
|
idp_id=idp_id,
|
||||||
|
idp_name=oidc_config.get("idp_name", "OIDC"),
|
||||||
discover=oidc_config.get("discover", True),
|
discover=oidc_config.get("discover", True),
|
||||||
issuer=oidc_config["issuer"],
|
issuer=oidc_config["issuer"],
|
||||||
client_id=oidc_config["client_id"],
|
client_id=oidc_config["client_id"],
|
||||||
|
@ -296,8 +309,15 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s(slots=True, frozen=True)
|
||||||
class OidcProviderConfig:
|
class OidcProviderConfig:
|
||||||
|
# a unique identifier for this identity provider. Used in the 'user_external_ids'
|
||||||
|
# table, as well as the query/path parameter used in the login protocol.
|
||||||
|
idp_id = attr.ib(type=str)
|
||||||
|
|
||||||
|
# user-facing name for this identity provider.
|
||||||
|
idp_name = attr.ib(type=str)
|
||||||
|
|
||||||
# whether the OIDC discovery mechanism is used to discover endpoints
|
# whether the OIDC discovery mechanism is used to discover endpoints
|
||||||
discover = attr.ib(type=bool)
|
discover = attr.ib(type=bool)
|
||||||
|
|
||||||
|
|
|
@ -175,7 +175,7 @@ class OidcHandler:
|
||||||
session_data = self._token_generator.verify_oidc_session_token(
|
session_data = self._token_generator.verify_oidc_session_token(
|
||||||
session, state
|
session, state
|
||||||
)
|
)
|
||||||
except MacaroonDeserializationException as e:
|
except (MacaroonDeserializationException, ValueError) as e:
|
||||||
logger.exception("Invalid session")
|
logger.exception("Invalid session")
|
||||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||||
return
|
return
|
||||||
|
@ -253,10 +253,10 @@ class OidcProvider:
|
||||||
self._server_name = hs.config.server_name # type: str
|
self._server_name = hs.config.server_name # type: str
|
||||||
|
|
||||||
# identifier for the external_ids table
|
# identifier for the external_ids table
|
||||||
self.idp_id = "oidc"
|
self.idp_id = provider.idp_id
|
||||||
|
|
||||||
# user-facing name of this auth provider
|
# user-facing name of this auth provider
|
||||||
self.idp_name = "OIDC"
|
self.idp_name = provider.idp_name
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
|
@ -656,6 +656,7 @@ class OidcProvider:
|
||||||
cookie = self._token_generator.generate_oidc_session_token(
|
cookie = self._token_generator.generate_oidc_session_token(
|
||||||
state=state,
|
state=state,
|
||||||
session_data=OidcSessionData(
|
session_data=OidcSessionData(
|
||||||
|
idp_id=self.idp_id,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
client_redirect_url=client_redirect_url.decode(),
|
client_redirect_url=client_redirect_url.decode(),
|
||||||
ui_auth_session_id=ui_auth_session_id,
|
ui_auth_session_id=ui_auth_session_id,
|
||||||
|
@ -924,6 +925,7 @@ class OidcSessionTokenGenerator:
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = session")
|
macaroon.add_first_party_caveat("type = session")
|
||||||
macaroon.add_first_party_caveat("state = %s" % (state,))
|
macaroon.add_first_party_caveat("state = %s" % (state,))
|
||||||
|
macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
|
||||||
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
|
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
|
||||||
macaroon.add_first_party_caveat(
|
macaroon.add_first_party_caveat(
|
||||||
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
||||||
|
@ -952,6 +954,9 @@ class OidcSessionTokenGenerator:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The data extracted from the session cookie
|
The data extracted from the session cookie
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if an expected caveat is missing from the macaroon.
|
||||||
"""
|
"""
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(session)
|
macaroon = pymacaroons.Macaroon.deserialize(session)
|
||||||
|
|
||||||
|
@ -960,6 +965,7 @@ class OidcSessionTokenGenerator:
|
||||||
v.satisfy_exact("type = session")
|
v.satisfy_exact("type = session")
|
||||||
v.satisfy_exact("state = %s" % (state,))
|
v.satisfy_exact("state = %s" % (state,))
|
||||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||||
|
v.satisfy_general(lambda c: c.startswith("idp_id = "))
|
||||||
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
||||||
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
|
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
|
||||||
# to always satisfy this.
|
# to always satisfy this.
|
||||||
|
@ -968,9 +974,9 @@ class OidcSessionTokenGenerator:
|
||||||
|
|
||||||
v.verify(macaroon, self._macaroon_secret_key)
|
v.verify(macaroon, self._macaroon_secret_key)
|
||||||
|
|
||||||
# Extract the `nonce`, `client_redirect_url`, and maybe the
|
# Extract the session data from the token.
|
||||||
# `ui_auth_session_id` from the token.
|
|
||||||
nonce = self._get_value_from_macaroon(macaroon, "nonce")
|
nonce = self._get_value_from_macaroon(macaroon, "nonce")
|
||||||
|
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
|
||||||
client_redirect_url = self._get_value_from_macaroon(
|
client_redirect_url = self._get_value_from_macaroon(
|
||||||
macaroon, "client_redirect_url"
|
macaroon, "client_redirect_url"
|
||||||
)
|
)
|
||||||
|
@ -983,6 +989,7 @@ class OidcSessionTokenGenerator:
|
||||||
|
|
||||||
return OidcSessionData(
|
return OidcSessionData(
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
|
idp_id=idp_id,
|
||||||
client_redirect_url=client_redirect_url,
|
client_redirect_url=client_redirect_url,
|
||||||
ui_auth_session_id=ui_auth_session_id,
|
ui_auth_session_id=ui_auth_session_id,
|
||||||
)
|
)
|
||||||
|
@ -998,7 +1005,7 @@ class OidcSessionTokenGenerator:
|
||||||
The extracted value
|
The extracted value
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: if the caveat was not in the macaroon
|
ValueError: if the caveat was not in the macaroon
|
||||||
"""
|
"""
|
||||||
prefix = key + " = "
|
prefix = key + " = "
|
||||||
for caveat in macaroon.caveats:
|
for caveat in macaroon.caveats:
|
||||||
|
@ -1019,6 +1026,9 @@ class OidcSessionTokenGenerator:
|
||||||
class OidcSessionData:
|
class OidcSessionData:
|
||||||
"""The attributes which are stored in a OIDC session cookie"""
|
"""The attributes which are stored in a OIDC session cookie"""
|
||||||
|
|
||||||
|
# the Identity Provider being used
|
||||||
|
idp_id = attr.ib(type=str)
|
||||||
|
|
||||||
# The `nonce` parameter passed to the OIDC provider.
|
# The `nonce` parameter passed to the OIDC provider.
|
||||||
nonce = attr.ib(type=str)
|
nonce = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
|
@ -848,6 +848,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
return self.handler._token_generator.generate_oidc_session_token(
|
return self.handler._token_generator.generate_oidc_session_token(
|
||||||
state=state,
|
state=state,
|
||||||
session_data=OidcSessionData(
|
session_data=OidcSessionData(
|
||||||
|
idp_id="oidc",
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
client_redirect_url=client_redirect_url,
|
client_redirect_url=client_redirect_url,
|
||||||
ui_auth_session_id=ui_auth_session_id,
|
ui_auth_session_id=ui_auth_session_id,
|
||||||
|
@ -990,7 +991,7 @@ async def _make_callback_with_userinfo(
|
||||||
session = handler._token_generator.generate_oidc_session_token(
|
session = handler._token_generator.generate_oidc_session_token(
|
||||||
state=state,
|
state=state,
|
||||||
session_data=OidcSessionData(
|
session_data=OidcSessionData(
|
||||||
nonce="nonce", client_redirect_url=client_redirect_url,
|
idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
request = _build_callback_request("code", state, session)
|
request = _build_callback_request("code", state, session)
|
||||||
|
|
Loading…
Reference in a new issue