mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 23:03:51 +01:00
Pass module API to OIDC mapping provider (#16974)
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.
This commit is contained in:
parent
05489d89c6
commit
74ab329eaa
3 changed files with 18 additions and 4 deletions
1
changelog.d/16974.misc
Normal file
1
changelog.d/16974.misc
Normal file
|
@ -0,0 +1 @@
|
|||
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.
|
|
@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead.
|
|||
|
||||
A custom mapping provider must specify the following methods:
|
||||
|
||||
* `def __init__(self, parsed_config)`
|
||||
* `def __init__(self, parsed_config, module_api)`
|
||||
- Arguments:
|
||||
- `parsed_config` - A configuration object that is the return value of the
|
||||
`parse_config` method. You should set any configuration options needed by
|
||||
the module here.
|
||||
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
|
||||
stable API available for extension modules.
|
||||
* `def parse_config(config)`
|
||||
- This method should have the `@staticmethod` decoration.
|
||||
- Arguments:
|
||||
|
|
|
@ -65,6 +65,7 @@ from synapse.http.server import finish_request
|
|||
from synapse.http.servlet import parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||
from synapse.util import Clock, json_decoder
|
||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||
|
@ -421,9 +422,19 @@ class OidcProvider:
|
|||
# from the IdP's jwks_uri, if required.
|
||||
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
|
||||
|
||||
self._user_mapping_provider = provider.user_mapping_provider_class(
|
||||
provider.user_mapping_provider_config
|
||||
user_mapping_provider_init_method = (
|
||||
provider.user_mapping_provider_class.__init__
|
||||
)
|
||||
if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
|
||||
self._user_mapping_provider = provider.user_mapping_provider_class(
|
||||
provider.user_mapping_provider_config,
|
||||
ModuleApi(hs, hs.get_auth_handler()),
|
||||
)
|
||||
else:
|
||||
self._user_mapping_provider = provider.user_mapping_provider_class(
|
||||
provider.user_mapping_provider_config,
|
||||
)
|
||||
|
||||
self._skip_verification = provider.skip_verification
|
||||
self._allow_existing_users = provider.allow_existing_users
|
||||
|
||||
|
@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
This is the default mapping provider.
|
||||
"""
|
||||
|
||||
def __init__(self, config: JinjaOidcMappingConfig):
|
||||
def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
|
||||
self._config = config
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in a new issue