Extract OIDCProviderConfig object

Collect all the config options which related to an OIDC provider into a single
object.
This commit is contained in:
Richard van der Hoff 2020-12-18 12:13:03 +00:00
parent 98a64b7f7f
commit 7cc9509eca
2 changed files with 140 additions and 62 deletions

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech # Copyright 2020 Quentin Gliech
# Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,7 +14,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Type
import attr
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -25,65 +31,29 @@ class OIDCConfig(Config):
section = "oidc" section = "oidc"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.oidc_enabled = False self.oidc_provider = None # type: Optional[OidcProviderConfig]
oidc_config = config.get("oidc_config") oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
if not oidc_config or not oidc_config.get("enabled", False): if not self.oidc_provider:
return return
try: try:
check_requirements("oidc") check_requirements("oidc")
except DependencyException as e: except DependencyException as e:
raise ConfigError(e.message) raise ConfigError(e.message) from e
public_baseurl = self.public_baseurl public_baseurl = self.public_baseurl
if public_baseurl is None: if public_baseurl is None:
raise ConfigError("oidc_config requires a public_baseurl to be set") raise ConfigError("oidc_config requires a public_baseurl to be set")
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
self.oidc_enabled = True @property
self.oidc_discover = oidc_config.get("discover", True) def oidc_enabled(self) -> bool:
self.oidc_issuer = oidc_config["issuer"] # OIDC is enabled if we have a provider
self.oidc_client_id = oidc_config["client_id"] return bool(self.oidc_provider)
self.oidc_client_secret = oidc_config["client_secret"]
self.oidc_client_auth_method = oidc_config.get(
"client_auth_method", "client_secret_basic"
)
self.oidc_scopes = oidc_config.get("scopes", ["openid"])
self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
self.oidc_token_endpoint = oidc_config.get("token_endpoint")
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
"get_remote_user_id",
"map_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(self.oidc_user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
def generate_config_section(self, config_dir_path, server_name, **kwargs): def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\ return """\
@ -224,3 +194,108 @@ class OIDCConfig(Config):
""".format( """.format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
) )
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig
Raises:
ConfigError if the configuration is malformed.
"""
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(user_mapping_provider_class, user_mapping_provider_config,) = load_module(
ump_config, ("oidc_config", "user_mapping_provider")
)
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
"get_remote_user_id",
"map_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by oidc_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
return OidcProviderConfig(
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
client_secret=oidc_config["client_secret"],
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
token_endpoint=oidc_config.get("token_endpoint"),
userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
jwks_uri=oidc_config.get("jwks_uri"),
skip_verification=oidc_config.get("skip_verification", False),
user_profile_method=oidc_config.get("user_profile_method", "auto"),
allow_existing_users=oidc_config.get("allow_existing_users", False),
user_mapping_provider_class=user_mapping_provider_class,
user_mapping_provider_config=user_mapping_provider_config,
)
@attr.s
class OidcProviderConfig:
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
issuer = attr.ib(type=str)
# oauth2 client id to use
client_id = attr.ib(type=str)
# oauth2 client secret to use
client_secret = attr.ib(type=str)
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
# 'none'.
client_auth_method = attr.ib(type=str)
# list of scopes to request
scopes = attr.ib(type=Collection[str])
# the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint = attr.ib(type=Optional[str])
# the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint = attr.ib(type=Optional[str])
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
userinfo_endpoint = attr.ib(type=Optional[str])
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
jwks_uri = attr.ib(type=Optional[str])
# Whether to skip metadata verification
skip_verification = attr.ib(type=bool)
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
user_profile_method = attr.ib(type=str)
# whether to allow a user logging in via OIDC to match a pre-existing account
# instead of failing
allow_existing_users = attr.ib(type=bool)
# the class of the user mapping provider
user_mapping_provider_class = attr.ib(type=Type)
# the config of the user mapping provider
user_mapping_provider_config = attr.ib()

View file

@ -94,27 +94,30 @@ class OidcHandler:
self._token_generator = OidcSessionTokenGenerator(hs) self._token_generator = OidcSessionTokenGenerator(hs)
self._callback_url = hs.config.oidc_callback_url # type: str self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str provider = hs.config.oidc.oidc_provider
# we should not have been instantiated if there is no configured provider.
assert provider is not None
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth( self._client_auth = ClientAuth(
hs.config.oidc_client_id, provider.client_id, provider.client_secret, provider.client_auth_method,
hs.config.oidc_client_secret,
hs.config.oidc_client_auth_method,
) # type: ClientAuth ) # type: ClientAuth
self._client_auth_method = hs.config.oidc_client_auth_method # type: str self._client_auth_method = provider.client_auth_method
self._provider_metadata = OpenIDProviderMetadata( self._provider_metadata = OpenIDProviderMetadata(
issuer=hs.config.oidc_issuer, issuer=provider.issuer,
authorization_endpoint=hs.config.oidc_authorization_endpoint, authorization_endpoint=provider.authorization_endpoint,
token_endpoint=hs.config.oidc_token_endpoint, token_endpoint=provider.token_endpoint,
userinfo_endpoint=hs.config.oidc_userinfo_endpoint, userinfo_endpoint=provider.userinfo_endpoint,
jwks_uri=hs.config.oidc_jwks_uri, jwks_uri=provider.jwks_uri,
) # type: OpenIDProviderMetadata ) # type: OpenIDProviderMetadata
self._provider_needs_discovery = hs.config.oidc_discover # type: bool self._provider_needs_discovery = provider.discover
self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( self._user_mapping_provider = provider.user_mapping_provider_class(
hs.config.oidc_user_mapping_provider_config provider.user_mapping_provider_config
) # type: OidcMappingProvider )
self._skip_verification = hs.config.oidc_skip_verification # type: bool self._skip_verification = provider.skip_verification
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str self._server_name = hs.config.server_name # type: str