mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 16:43:53 +01:00
Extract OIDCProviderConfig object
Collect all the config options which related to an OIDC provider into a single object.
This commit is contained in:
parent
98a64b7f7f
commit
7cc9509eca
2 changed files with 140 additions and 62 deletions
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Quentin Gliech
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,7 +14,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
from synapse.types import Collection, JsonDict
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
@ -25,65 +31,29 @@ class OIDCConfig(Config):
|
|||
section = "oidc"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
self.oidc_enabled = False
|
||||
self.oidc_provider = None # type: Optional[OidcProviderConfig]
|
||||
|
||||
oidc_config = config.get("oidc_config")
|
||||
if oidc_config and oidc_config.get("enabled", False):
|
||||
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
|
||||
|
||||
if not oidc_config or not oidc_config.get("enabled", False):
|
||||
if not self.oidc_provider:
|
||||
return
|
||||
|
||||
try:
|
||||
check_requirements("oidc")
|
||||
except DependencyException as e:
|
||||
raise ConfigError(e.message)
|
||||
raise ConfigError(e.message) from e
|
||||
|
||||
public_baseurl = self.public_baseurl
|
||||
if public_baseurl is None:
|
||||
raise ConfigError("oidc_config requires a public_baseurl to be set")
|
||||
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
|
||||
|
||||
self.oidc_enabled = True
|
||||
self.oidc_discover = oidc_config.get("discover", True)
|
||||
self.oidc_issuer = oidc_config["issuer"]
|
||||
self.oidc_client_id = oidc_config["client_id"]
|
||||
self.oidc_client_secret = oidc_config["client_secret"]
|
||||
self.oidc_client_auth_method = oidc_config.get(
|
||||
"client_auth_method", "client_secret_basic"
|
||||
)
|
||||
self.oidc_scopes = oidc_config.get("scopes", ["openid"])
|
||||
self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
|
||||
self.oidc_token_endpoint = oidc_config.get("token_endpoint")
|
||||
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
|
||||
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
|
||||
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
|
||||
self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
|
||||
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
|
||||
|
||||
ump_config = oidc_config.get("user_mapping_provider", {})
|
||||
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
|
||||
ump_config.setdefault("config", {})
|
||||
|
||||
(
|
||||
self.oidc_user_mapping_provider_class,
|
||||
self.oidc_user_mapping_provider_config,
|
||||
) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
|
||||
|
||||
# Ensure loaded user mapping module has defined all necessary methods
|
||||
required_methods = [
|
||||
"get_remote_user_id",
|
||||
"map_user_attributes",
|
||||
]
|
||||
missing_methods = [
|
||||
method
|
||||
for method in required_methods
|
||||
if not hasattr(self.oidc_user_mapping_provider_class, method)
|
||||
]
|
||||
if missing_methods:
|
||||
raise ConfigError(
|
||||
"Class specified by oidc_config."
|
||||
"user_mapping_provider.module is missing required "
|
||||
"methods: %s" % (", ".join(missing_methods),)
|
||||
)
|
||||
@property
|
||||
def oidc_enabled(self) -> bool:
|
||||
# OIDC is enabled if we have a provider
|
||||
return bool(self.oidc_provider)
|
||||
|
||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """\
|
||||
|
@ -224,3 +194,108 @@ class OIDCConfig(Config):
|
|||
""".format(
|
||||
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
|
||||
)
|
||||
|
||||
|
||||
def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
|
||||
"""Take the configuration dict and parse it into an OidcProviderConfig
|
||||
|
||||
Raises:
|
||||
ConfigError if the configuration is malformed.
|
||||
"""
|
||||
ump_config = oidc_config.get("user_mapping_provider", {})
|
||||
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
|
||||
ump_config.setdefault("config", {})
|
||||
|
||||
(user_mapping_provider_class, user_mapping_provider_config,) = load_module(
|
||||
ump_config, ("oidc_config", "user_mapping_provider")
|
||||
)
|
||||
|
||||
# Ensure loaded user mapping module has defined all necessary methods
|
||||
required_methods = [
|
||||
"get_remote_user_id",
|
||||
"map_user_attributes",
|
||||
]
|
||||
missing_methods = [
|
||||
method
|
||||
for method in required_methods
|
||||
if not hasattr(user_mapping_provider_class, method)
|
||||
]
|
||||
if missing_methods:
|
||||
raise ConfigError(
|
||||
"Class specified by oidc_config."
|
||||
"user_mapping_provider.module is missing required "
|
||||
"methods: %s" % (", ".join(missing_methods),)
|
||||
)
|
||||
|
||||
return OidcProviderConfig(
|
||||
discover=oidc_config.get("discover", True),
|
||||
issuer=oidc_config["issuer"],
|
||||
client_id=oidc_config["client_id"],
|
||||
client_secret=oidc_config["client_secret"],
|
||||
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
|
||||
scopes=oidc_config.get("scopes", ["openid"]),
|
||||
authorization_endpoint=oidc_config.get("authorization_endpoint"),
|
||||
token_endpoint=oidc_config.get("token_endpoint"),
|
||||
userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
|
||||
jwks_uri=oidc_config.get("jwks_uri"),
|
||||
skip_verification=oidc_config.get("skip_verification", False),
|
||||
user_profile_method=oidc_config.get("user_profile_method", "auto"),
|
||||
allow_existing_users=oidc_config.get("allow_existing_users", False),
|
||||
user_mapping_provider_class=user_mapping_provider_class,
|
||||
user_mapping_provider_config=user_mapping_provider_config,
|
||||
)
|
||||
|
||||
|
||||
@attr.s
|
||||
class OidcProviderConfig:
|
||||
# whether the OIDC discovery mechanism is used to discover endpoints
|
||||
discover = attr.ib(type=bool)
|
||||
|
||||
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
|
||||
# discover the provider's endpoints.
|
||||
issuer = attr.ib(type=str)
|
||||
|
||||
# oauth2 client id to use
|
||||
client_id = attr.ib(type=str)
|
||||
|
||||
# oauth2 client secret to use
|
||||
client_secret = attr.ib(type=str)
|
||||
|
||||
# auth method to use when exchanging the token.
|
||||
# Valid values are 'client_secret_basic', 'client_secret_post' and
|
||||
# 'none'.
|
||||
client_auth_method = attr.ib(type=str)
|
||||
|
||||
# list of scopes to request
|
||||
scopes = attr.ib(type=Collection[str])
|
||||
|
||||
# the oauth2 authorization endpoint. Required if discovery is disabled.
|
||||
authorization_endpoint = attr.ib(type=Optional[str])
|
||||
|
||||
# the oauth2 token endpoint. Required if discovery is disabled.
|
||||
token_endpoint = attr.ib(type=Optional[str])
|
||||
|
||||
# the OIDC userinfo endpoint. Required if discovery is disabled and the
|
||||
# "openid" scope is not requested.
|
||||
userinfo_endpoint = attr.ib(type=Optional[str])
|
||||
|
||||
# URI where to fetch the JWKS. Required if discovery is disabled and the
|
||||
# "openid" scope is used.
|
||||
jwks_uri = attr.ib(type=Optional[str])
|
||||
|
||||
# Whether to skip metadata verification
|
||||
skip_verification = attr.ib(type=bool)
|
||||
|
||||
# Whether to fetch the user profile from the userinfo endpoint. Valid
|
||||
# values are: "auto" or "userinfo_endpoint".
|
||||
user_profile_method = attr.ib(type=str)
|
||||
|
||||
# whether to allow a user logging in via OIDC to match a pre-existing account
|
||||
# instead of failing
|
||||
allow_existing_users = attr.ib(type=bool)
|
||||
|
||||
# the class of the user mapping provider
|
||||
user_mapping_provider_class = attr.ib(type=Type)
|
||||
|
||||
# the config of the user mapping provider
|
||||
user_mapping_provider_config = attr.ib()
|
||||
|
|
|
@ -94,27 +94,30 @@ class OidcHandler:
|
|||
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
||||
|
||||
provider = hs.config.oidc.oidc_provider
|
||||
# we should not have been instantiated if there is no configured provider.
|
||||
assert provider is not None
|
||||
|
||||
self._scopes = provider.scopes
|
||||
self._user_profile_method = provider.user_profile_method
|
||||
self._client_auth = ClientAuth(
|
||||
hs.config.oidc_client_id,
|
||||
hs.config.oidc_client_secret,
|
||||
hs.config.oidc_client_auth_method,
|
||||
provider.client_id, provider.client_secret, provider.client_auth_method,
|
||||
) # type: ClientAuth
|
||||
self._client_auth_method = hs.config.oidc_client_auth_method # type: str
|
||||
self._client_auth_method = provider.client_auth_method
|
||||
self._provider_metadata = OpenIDProviderMetadata(
|
||||
issuer=hs.config.oidc_issuer,
|
||||
authorization_endpoint=hs.config.oidc_authorization_endpoint,
|
||||
token_endpoint=hs.config.oidc_token_endpoint,
|
||||
userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
|
||||
jwks_uri=hs.config.oidc_jwks_uri,
|
||||
issuer=provider.issuer,
|
||||
authorization_endpoint=provider.authorization_endpoint,
|
||||
token_endpoint=provider.token_endpoint,
|
||||
userinfo_endpoint=provider.userinfo_endpoint,
|
||||
jwks_uri=provider.jwks_uri,
|
||||
) # type: OpenIDProviderMetadata
|
||||
self._provider_needs_discovery = hs.config.oidc_discover # type: bool
|
||||
self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
|
||||
hs.config.oidc_user_mapping_provider_config
|
||||
) # type: OidcMappingProvider
|
||||
self._skip_verification = hs.config.oidc_skip_verification # type: bool
|
||||
self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
|
||||
self._provider_needs_discovery = provider.discover
|
||||
self._user_mapping_provider = provider.user_mapping_provider_class(
|
||||
provider.user_mapping_provider_config
|
||||
)
|
||||
self._skip_verification = provider.skip_verification
|
||||
self._allow_existing_users = provider.allow_existing_users
|
||||
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._server_name = hs.config.server_name # type: str
|
||||
|
|
Loading…
Reference in a new issue