diff --git a/changelog.d/16974.misc b/changelog.d/16974.misc new file mode 100644 index 000000000..bf0a13786 --- /dev/null +++ b/changelog.d/16974.misc @@ -0,0 +1 @@ +As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 77cc02c54..10c695029 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead. A custom mapping provider must specify the following methods: -* `def __init__(self, parsed_config)` +* `def __init__(self, parsed_config, module_api)` - Arguments: - `parsed_config` - A configuration object that is the return value of the `parse_config` method. You should set any configuration options needed by the module here. + - `module_api` - a `synapse.module_api.ModuleApi` object which provides the + stable API available for extension modules. * `def parse_config(config)` - This method should have the `@staticmethod` decoration. - Arguments: diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index fe13d82b6..ba67cc476 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -65,6 +65,7 @@ from synapse.http.server import finish_request from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable +from synapse.module_api import ModuleApi from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import Clock, json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall @@ -421,9 +422,19 @@ class OidcProvider: # from the IdP's jwks_uri, if required. self._jwks = RetryOnExceptionCachedCall(self._load_jwks) - self._user_mapping_provider = provider.user_mapping_provider_class( - provider.user_mapping_provider_config + user_mapping_provider_init_method = ( + provider.user_mapping_provider_class.__init__ ) + if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3: + self._user_mapping_provider = provider.user_mapping_provider_class( + provider.user_mapping_provider_config, + ModuleApi(hs, hs.get_auth_handler()), + ) + else: + self._user_mapping_provider = provider.user_mapping_provider_class( + provider.user_mapping_provider_config, + ) + self._skip_verification = provider.skip_verification self._allow_existing_users = provider.allow_existing_users @@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): This is the default mapping provider. """ - def __init__(self, config: JinjaOidcMappingConfig): + def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi): self._config = config @staticmethod