diff --git a/changelog.d/9626.feature b/changelog.d/9626.feature new file mode 100644 index 0000000000..eacba6201b --- /dev/null +++ b/changelog.d/9626.feature @@ -0,0 +1 @@ +Tell spam checker modules about the SSO IdP a user registered through if one was used. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 2020eb9006..52947f605e 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -69,7 +69,13 @@ class ExampleSpamChecker: async def check_username_for_spam(self, user_profile): return False # allow all usernames - async def check_registration_for_spam(self, email_threepid, username, request_info): + async def check_registration_for_spam( + self, + email_threepid, + username, + request_info, + auth_provider_id, + ): return RegistrationBehaviour.ALLOW # allow all registrations async def check_media_file_for_spam(self, file_wrapper, file_info): diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 8cfc0bb3cb..a9185987a2 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,6 +15,7 @@ # limitations under the License. import inspect +import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from synapse.rest.media.v1._base import FileInfo @@ -27,6 +28,8 @@ if TYPE_CHECKING: import synapse.events import synapse.server +logger = logging.getLogger(__name__) + class SpamChecker: def __init__(self, hs: "synapse.server.HomeServer"): @@ -190,6 +193,7 @@ class SpamChecker: email_threepid: Optional[dict], username: Optional[str], request_info: Collection[Tuple[str, str]], + auth_provider_id: Optional[str] = None, ) -> RegistrationBehaviour: """Checks if we should allow the given registration request. @@ -198,6 +202,9 @@ class SpamChecker: username: The request user name, if any request_info: List of tuples of user agent and IP that were used during the registration process. + auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml", + "cas". If any. Note this does not include users registered + via a password provider. Returns: Enum for how the request should be handled @@ -208,9 +215,25 @@ class SpamChecker: # spam checker checker = getattr(spam_checker, "check_registration_for_spam", None) if checker: - behaviour = await maybe_awaitable( - checker(email_threepid, username, request_info) - ) + # Provide auth_provider_id if the function supports it + checker_args = inspect.signature(checker) + if len(checker_args.parameters) == 4: + d = checker( + email_threepid, + username, + request_info, + auth_provider_id, + ) + elif len(checker_args.parameters) == 3: + d = checker(email_threepid, username, request_info) + else: + logger.error( + "Invalid signature for %s.check_registration_for_spam. Denying registration", + spam_checker.__module__, + ) + return RegistrationBehaviour.DENY + + behaviour = await maybe_awaitable(d) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cd001e87c7..1abc8875cb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -198,8 +198,7 @@ class RegistrationHandler(BaseHandler): admin api, otherwise False. user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. Returns: The registered user_id. Raises: @@ -211,6 +210,7 @@ class RegistrationHandler(BaseHandler): threepid, localpart, user_agent_ips or [], + auth_provider_id=auth_provider_id, ) if result == RegistrationBehaviour.DENY: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index bdf3d0a8a2..94b6903594 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -517,6 +517,37 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(requester.shadow_banned) + def test_spam_checker_receives_sso_type(self): + """Test rejecting registration based on SSO type""" + + class BanBadIdPUser: + def check_registration_for_spam( + self, email_threepid, username, request_info, auth_provider_id=None + ): + # Reject any user coming from CAS and whose username contains profanity + if auth_provider_id == "cas" and "flimflob" in username: + return RegistrationBehaviour.DENY + return RegistrationBehaviour.ALLOW + + # Configure a spam checker that denies a certain user on a specific IdP + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanBadIdPUser()] + + f = self.get_failure( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), + SynapseError, + ) + exception = f.value + + # We return 429 from the spam checker for denied registrations + self.assertIsInstance(exception, SynapseError) + self.assertEqual(exception.code, 429) + + # Check the same username can register using SAML + self.get_success( + self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml") + ) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ):