Pass module API to OIDC mapping provider (#16974)

As done for SAML mapping provider, let's pass the module API to the OIDC
one so the mapper can do more logic in its code.
This commit is contained in:
Mathieu Velten 2024-03-19 18:20:10 +01:00 committed by GitHub
parent 05489d89c6
commit 74ab329eaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 4 deletions

1
changelog.d/16974.misc Normal file
View File

@ -0,0 +1 @@
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.

View File

@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
* `def __init__(self, parsed_config)`
* `def __init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
* `def parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:

View File

@ -65,6 +65,7 @@ from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
@ -421,9 +422,19 @@ class OidcProvider:
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
user_mapping_provider_init_method = (
provider.user_mapping_provider_class.__init__
)
if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
else:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
)
self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users
@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
This is the default mapping provider.
"""
def __init__(self, config: JinjaOidcMappingConfig):
def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
self._config = config
@staticmethod