mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-30 02:34:04 +01:00
ModuleAPI SSO auth callbacks (#15207)
Signed-off-by: Andrii Yasynyshyn yasinishyn.a.n@gmail.com
This commit is contained in:
parent
579c6be5f6
commit
63d96bfc61
8 changed files with 56 additions and 2 deletions
1
changelog.d/15207.feature
Normal file
1
changelog.d/15207.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Adds on_user_login ModuleAPI callback allowing to execute custom code after (on) Auth.
|
|
@ -42,3 +42,16 @@ operations to keep track of them. (e.g. add them to a database table). The user
|
||||||
represented by their Matrix user ID.
|
represented by their Matrix user ID.
|
||||||
|
|
||||||
If multiple modules implement this callback, Synapse runs them all in order.
|
If multiple modules implement this callback, Synapse runs them all in order.
|
||||||
|
|
||||||
|
### `on_user_login`
|
||||||
|
|
||||||
|
_First introduced in Synapse v1.98.0_
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def on_user_login(user_id: str, auth_provider_type: str, auth_provider_id: str) -> None
|
||||||
|
```
|
||||||
|
|
||||||
|
Called after successfully login or registration of a user for cases when module needs to perform extra operations after auth.
|
||||||
|
represented by their Matrix user ID.
|
||||||
|
|
||||||
|
If multiple modules implement this callback, Synapse runs them all in order.
|
||||||
|
|
|
@ -296,8 +296,7 @@ impl<'source> FromPyObject<'source> for JsonValue {
|
||||||
match l.iter().map(SimpleJsonValue::extract).collect() {
|
match l.iter().map(SimpleJsonValue::extract).collect() {
|
||||||
Ok(a) => Ok(JsonValue::Array(a)),
|
Ok(a) => Ok(JsonValue::Array(a)),
|
||||||
Err(e) => Err(PyTypeError::new_err(format!(
|
Err(e) => Err(PyTypeError::new_err(format!(
|
||||||
"Can't convert to JsonValue::Array: {}",
|
"Can't convert to JsonValue::Array: {e}"
|
||||||
e
|
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
} else if let Ok(v) = SimpleJsonValue::extract(ob) {
|
} else if let Ok(v) = SimpleJsonValue::extract(ob) {
|
||||||
|
|
|
@ -98,6 +98,22 @@ class AccountValidityHandler:
|
||||||
for callback in self._module_api_callbacks.on_user_registration_callbacks:
|
for callback in self._module_api_callbacks.on_user_registration_callbacks:
|
||||||
await callback(user_id)
|
await callback(user_id)
|
||||||
|
|
||||||
|
async def on_user_login(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
auth_provider_type: Optional[str],
|
||||||
|
auth_provider_id: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Tell third-party modules about a user logins.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The mxID of the user.
|
||||||
|
auth_provider_type: The type of login.
|
||||||
|
auth_provider_id: The ID of the auth provider.
|
||||||
|
"""
|
||||||
|
for callback in self._module_api_callbacks.on_user_login_callbacks:
|
||||||
|
await callback(user_id, auth_provider_type, auth_provider_id)
|
||||||
|
|
||||||
@wrap_as_background_process("send_renewals")
|
@wrap_as_background_process("send_renewals")
|
||||||
async def _send_renewal_emails(self) -> None:
|
async def _send_renewal_emails(self) -> None:
|
||||||
"""Gets the list of users whose account is expiring in the amount of time
|
"""Gets the list of users whose account is expiring in the amount of time
|
||||||
|
|
|
@ -212,6 +212,7 @@ class AuthHandler:
|
||||||
self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
|
self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
|
||||||
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
|
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
|
||||||
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
|
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
|
||||||
|
self._account_validity_handler = hs.get_account_validity_handler()
|
||||||
|
|
||||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||||
# as per `rc_login.failed_attempts`.
|
# as per `rc_login.failed_attempts`.
|
||||||
|
@ -1783,6 +1784,13 @@ class AuthHandler:
|
||||||
client_redirect_url, "loginToken", login_token
|
client_redirect_url, "loginToken", login_token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Run post-login module callback handlers
|
||||||
|
await self._account_validity_handler.on_user_login(
|
||||||
|
user_id=registered_user_id,
|
||||||
|
auth_provider_type=LoginType.SSO,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
# if the client is whitelisted, we can redirect straight to it
|
# if the client is whitelisted, we can redirect straight to it
|
||||||
if client_redirect_url.startswith(self._whitelisted_sso_clients):
|
if client_redirect_url.startswith(self._whitelisted_sso_clients):
|
||||||
request.redirect(redirect_url)
|
request.redirect(redirect_url)
|
||||||
|
|
|
@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
|
||||||
ON_LEGACY_ADMIN_REQUEST,
|
ON_LEGACY_ADMIN_REQUEST,
|
||||||
ON_LEGACY_RENEW_CALLBACK,
|
ON_LEGACY_RENEW_CALLBACK,
|
||||||
ON_LEGACY_SEND_MAIL_CALLBACK,
|
ON_LEGACY_SEND_MAIL_CALLBACK,
|
||||||
|
ON_USER_LOGIN_CALLBACK,
|
||||||
ON_USER_REGISTRATION_CALLBACK,
|
ON_USER_REGISTRATION_CALLBACK,
|
||||||
)
|
)
|
||||||
from synapse.module_api.callbacks.spamchecker_callbacks import (
|
from synapse.module_api.callbacks.spamchecker_callbacks import (
|
||||||
|
@ -334,6 +335,7 @@ class ModuleApi:
|
||||||
*,
|
*,
|
||||||
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
|
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
|
||||||
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
|
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
|
||||||
|
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
|
||||||
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
|
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
|
||||||
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
|
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
|
||||||
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
|
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
|
||||||
|
@ -345,6 +347,7 @@ class ModuleApi:
|
||||||
return self._callbacks.account_validity.register_callbacks(
|
return self._callbacks.account_validity.register_callbacks(
|
||||||
is_user_expired=is_user_expired,
|
is_user_expired=is_user_expired,
|
||||||
on_user_registration=on_user_registration,
|
on_user_registration=on_user_registration,
|
||||||
|
on_user_login=on_user_login,
|
||||||
on_legacy_send_mail=on_legacy_send_mail,
|
on_legacy_send_mail=on_legacy_send_mail,
|
||||||
on_legacy_renew=on_legacy_renew,
|
on_legacy_renew=on_legacy_renew,
|
||||||
on_legacy_admin_request=on_legacy_admin_request,
|
on_legacy_admin_request=on_legacy_admin_request,
|
||||||
|
|
|
@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
|
||||||
# Types for callbacks to be registered via the module api
|
# Types for callbacks to be registered via the module api
|
||||||
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
|
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
|
||||||
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
|
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
|
||||||
|
ON_USER_LOGIN_CALLBACK = Callable[[str, Optional[str], Optional[str]], Awaitable]
|
||||||
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
|
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
|
||||||
# to `/_synapse/client/account_validity`. See `register_callbacks` below.
|
# to `/_synapse/client/account_validity`. See `register_callbacks` below.
|
||||||
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
|
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
|
||||||
|
@ -33,6 +34,7 @@ class AccountValidityModuleApiCallbacks:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
|
self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
|
||||||
self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
|
self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
|
||||||
|
self.on_user_login_callbacks: List[ON_USER_LOGIN_CALLBACK] = []
|
||||||
self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None
|
self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None
|
||||||
self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
|
self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
|
||||||
|
|
||||||
|
@ -44,6 +46,7 @@ class AccountValidityModuleApiCallbacks:
|
||||||
self,
|
self,
|
||||||
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
|
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
|
||||||
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
|
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
|
||||||
|
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
|
||||||
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
|
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
|
||||||
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
|
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
|
||||||
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
|
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
|
||||||
|
@ -55,6 +58,9 @@ class AccountValidityModuleApiCallbacks:
|
||||||
if on_user_registration is not None:
|
if on_user_registration is not None:
|
||||||
self.on_user_registration_callbacks.append(on_user_registration)
|
self.on_user_registration_callbacks.append(on_user_registration)
|
||||||
|
|
||||||
|
if on_user_login is not None:
|
||||||
|
self.on_user_login_callbacks.append(on_user_login)
|
||||||
|
|
||||||
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
|
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
|
||||||
# an admin one). As part of moving the feature into a module, we need to change
|
# an admin one). As part of moving the feature into a module, we need to change
|
||||||
# the path from /_matrix/client/unstable/account_validity/... to
|
# the path from /_matrix/client/unstable/account_validity/... to
|
||||||
|
|
|
@ -115,6 +115,7 @@ class LoginRestServlet(RestServlet):
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
self._spam_checker = hs.get_module_api_callbacks().spam_checker
|
self._spam_checker = hs.get_module_api_callbacks().spam_checker
|
||||||
|
self._account_validity_handler = hs.get_account_validity_handler()
|
||||||
|
|
||||||
self._well_known_builder = WellKnownBuilder(hs)
|
self._well_known_builder = WellKnownBuilder(hs)
|
||||||
self._address_ratelimiter = Ratelimiter(
|
self._address_ratelimiter = Ratelimiter(
|
||||||
|
@ -470,6 +471,13 @@ class LoginRestServlet(RestServlet):
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# execute the callback
|
||||||
|
await self._account_validity_handler.on_user_login(
|
||||||
|
user_id,
|
||||||
|
auth_provider_type=login_submission.get("type"),
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
if valid_until_ms is not None:
|
if valid_until_ms is not None:
|
||||||
expires_in_ms = valid_until_ms - self.clock.time_msec()
|
expires_in_ms = valid_until_ms - self.clock.time_msec()
|
||||||
result["expires_in_ms"] = expires_in_ms
|
result["expires_in_ms"] = expires_in_ms
|
||||||
|
|
Loading…
Reference in a new issue