Allow denying or shadow banning registrations via the spam checker (#8034)

This commit is contained in:
Patrick Cloke 2020-08-20 15:42:58 -04:00 committed by GitHub
parent e259d63f73
commit 3f91638da6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 258 additions and 18 deletions

1
changelog.d/8034.feature Normal file
View file

@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

View file

@ -15,9 +15,10 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import Any, Dict, List from typing import Any, Dict, List, Optional, Tuple
from synapse.spam_checker_api import SpamCheckerApi from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
from synapse.types import Collection
MYPY = False MYPY = False
if MYPY: if MYPY:
@ -160,3 +161,33 @@ class SpamChecker(object):
return True return True
return False return False
def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
"""Checks if we should allow the given registration request.
Args:
email_threepid: The email threepid used for registering, if any
username: The request user name, if any
request_info: List of tuples of user agent and IP that
were used during the registration process.
Returns:
Enum for how the request should be handled
"""
for spam_checker in self.spam_checkers:
# For backwards compatibility, only run if the method exists on the
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
behaviour = checker(email_threepid, username, request_info)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
return RegistrationBehaviour.ALLOW

View file

@ -364,6 +364,14 @@ class AuthHandler(BaseHandler):
# authentication flow. # authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict) await self.store.set_ui_auth_clientdict(sid, clientdict)
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
)
if not authdict: if not authdict:
raise InteractiveAuthIncompleteError( raise InteractiveAuthIncompleteError(
session.session_id, self._auth_dict_for_flows(flows, session.session_id) session.session_id, self._auth_dict_for_flows(flows, session.session_id)

View file

@ -35,6 +35,7 @@ class CasHandler:
""" """
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self._hostname = hs.hostname self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
@ -210,8 +211,16 @@ class CasHandler:
else: else:
if not registered_user_id: if not registered_user_id:
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent", default=[b""]
)[0].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name localpart=localpart,
default_display_name=user_display_name,
user_agent_ips=(user_agent, ip_address),
) )
await self._auth_handler.complete_sso_login( await self._auth_handler.complete_sso_login(

View file

@ -93,6 +93,7 @@ class OidcHandler:
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str] self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth( self._client_auth = ClientAuth(
@ -689,9 +690,17 @@ class OidcHandler:
self._render_error(request, "invalid_token", str(e)) self._render_error(request, "invalid_token", str(e))
return return
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
# Call the mapper to register/login the user # Call the mapper to register/login the user
try: try:
user_id = await self._map_userinfo_to_user(userinfo, token) user_id = await self._map_userinfo_to_user(
userinfo, token, user_agent, ip_address
)
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e)) self._render_error(request, "mapping_error", str(e))
@ -828,7 +837,9 @@ class OidcHandler:
now = self._clock.time_msec() now = self._clock.time_msec()
return now < expiry return now < expiry
async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: async def _map_userinfo_to_user(
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
) -> str:
"""Maps a UserInfo object to a mxid. """Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim UserInfo should have a claim that uniquely identifies users. This claim
@ -843,6 +854,8 @@ class OidcHandler:
Args: Args:
userinfo: an object representing the user userinfo: an object representing the user
token: a dict with the tokens obtained from the provider token: a dict with the tokens obtained from the provider
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Raises: Raises:
MappingException: if there was an error while mapping some properties MappingException: if there was an error while mapping some properties
@ -899,7 +912,9 @@ class OidcHandler:
# It's the first time this user is logging in and the mapped mxid was # It's the first time this user is logging in and the mapped mxid was
# not taken, register the user # not taken, register the user
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=attributes["display_name"], localpart=localpart,
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
) )
await self._datastore.record_user_external_id( await self._datastore.record_user_external_id(

View file

@ -26,6 +26,7 @@ from synapse.replication.http.register import (
ReplicationPostRegisterActionsServlet, ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet, ReplicationRegisterServlet,
) )
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester from synapse.types import RoomAlias, UserID, create_requester
@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.spam_checker = hs.get_spam_checker()
if hs.config.worker_app: if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client( self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler):
address=None, address=None,
bind_emails=[], bind_emails=[],
by_admin=False, by_admin=False,
shadow_banned=False, user_agent_ips=None,
): ):
"""Registers a new client on the server. """Registers a new client on the server.
@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account. bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the by_admin (bool): True if this registration is being made via the
admin api, otherwise False. admin api, otherwise False.
shadow_banned (bool): Shadow-ban the created user. user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
during the registration process.
Returns: Returns:
str: user_id str: user_id
Raises: Raises:
@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler):
""" """
self.check_registration_ratelimit(address) self.check_registration_ratelimit(address)
result = self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
if result == RegistrationBehaviour.DENY:
logger.info(
"Blocked registration of %r", localpart,
)
# We return a 429 to make it not obvious that they've been
# denied.
raise SynapseError(429, "Rate limited")
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned:
logger.info(
"Shadow banning registration of %r", localpart,
)
# do not check_auth_blocking if the call is coming through the Admin API # do not check_auth_blocking if the call is coming through the Admin API
if not by_admin: if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid) await self.auth.check_auth_blocking(threepid=threepid)

View file

@ -54,6 +54,7 @@ class Saml2SessionData:
class SamlHandler: class SamlHandler:
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@ -133,8 +134,14 @@ class SamlHandler:
# the dict. # the dict.
self.expire_sessions() self.expire_sessions()
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
user_id, current_session = await self._map_saml_response_to_user( user_id, current_session = await self._map_saml_response_to_user(
resp_bytes, relay_state resp_bytes, relay_state, user_agent, ip_address
) )
# Complete the interactive auth session or the login. # Complete the interactive auth session or the login.
@ -147,7 +154,11 @@ class SamlHandler:
await self._auth_handler.complete_sso_login(user_id, request, relay_state) await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user( async def _map_saml_response_to_user(
self, resp_bytes: str, client_redirect_url: str self,
resp_bytes: str,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]: ) -> Tuple[str, Optional[Saml2SessionData]]:
""" """
Given a sample response, retrieve the cached session and user for it. Given a sample response, retrieve the cached session and user for it.
@ -155,6 +166,8 @@ class SamlHandler:
Args: Args:
resp_bytes: The SAML response. resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client. client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns: Returns:
Tuple of the user ID and SAML session associated with this response. Tuple of the user ID and SAML session associated with this response.
@ -291,6 +304,7 @@ class SamlHandler:
localpart=localpart, localpart=localpart,
default_display_name=displayname, default_display_name=displayname,
bind_emails=emails, bind_emails=emails,
user_agent_ips=(user_agent, ip_address),
) )
await self._datastore.record_user_external_id( await self._datastore.record_user_external_id(

View file

@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE, Codes.THREEPID_IN_USE,
) )
entries = await self.store.get_user_agents_ips_to_ui_auth_session(
session_id
)
registered_user_id = await self.registration_handler.register_user( registered_user_id = await self.registration_handler.register_user(
localpart=desired_username, localpart=desired_username,
password_hash=password_hash, password_hash=password_hash,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
threepid=threepid, threepid=threepid,
address=client_addr, address=client_addr,
user_agent_ips=entries,
) )
# Necessary due to auth checks prior to the threepid being # Necessary due to auth checks prior to the threepid being
# written to the db # written to the db

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from enum import Enum
from twisted.internet import defer from twisted.internet import defer
@ -25,6 +26,16 @@ if MYPY:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RegistrationBehaviour(Enum):
"""
Enum to define whether a registration request should allowed, denied, or shadow-banned.
"""
ALLOW = "allow"
SHADOW_BAN = "shadow_ban"
DENY = "deny"
class SpamCheckerApi(object): class SpamCheckerApi(object):
"""A proxy object that gets passed to spam checkers so they can get """A proxy object that gets passed to spam checkers so they can get
access to rooms and other relevant information. access to rooms and other relevant information.

View file

@ -0,0 +1,25 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- A table of the IP address and user-agent used to complete each step of a
-- user-interactive authentication session.
CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
session_id TEXT NOT NULL,
ip TEXT NOT NULL,
user_agent TEXT NOT NULL,
UNIQUE (session_id, ip, user_agent),
FOREIGN KEY (session_id)
REFERENCES ui_auth_sessions (session_id)
);

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
import attr import attr
@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default) return serverdict.get(key, default)
async def add_user_agent_ip_to_ui_auth_session(
self, session_id: str, user_agent: str, ip: str,
):
"""Add the given user agent / IP to the tracking table
"""
await self.db_pool.simple_upsert(
table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
values={},
desc="add_user_agent_ip_to_ui_auth_session",
)
async def get_user_agents_ips_to_ui_auth_session(
self, session_id: str,
) -> List[Tuple[str, str]]:
"""Get the given user agents / IPs used during the ui auth process
Returns:
List of user_agent/ip pairs
"""
rows = await self.db_pool.simple_select_list(
table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id},
retcols=("user_agent", "ip"),
desc="get_user_agents_ips_to_ui_auth_session",
)
return [(row["user_agent"], row["ip"]) for row in rows]
class UIAuthStore(UIAuthWorkerStore): class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int): def delete_old_ui_auth_sessions(self, expiration_time: int):
@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore):
txn.execute(sql, [expiration_time]) txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()] session_ids = [r[0] for r in txn.fetchall()]
# Delete the corresponding IP/user agents.
self.db_pool.simple_delete_many_txn(
txn,
table="ui_auth_sessions_ips",
column="session_id",
iterable=session_ids,
keyvalues={},
)
# Delete the corresponding completed credentials. # Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn( self.db_pool.simple_delete_many_txn(
txn, txn,

View file

@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock() self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(spec=["args", "getCookie", "addCookie"]) request = Mock(
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
)
code = "code" code = "code"
state = "state" state = "state"
nonce = "nonce" nonce = "nonce"
client_redirect_url = "http://client/redirect" client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token( session = self.handler._generate_oidc_session_token(
state=state, state=state,
nonce=nonce, nonce=nonce,
@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")] request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")]
request.requestHeaders = Mock(spec=["getRawHeaders"])
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with( self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
self.handler._exchange_code.assert_called_once_with(code) self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called() self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called() self.handler._render_error.assert_not_called()
@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
self.handler._exchange_code.assert_called_once_with(code) self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called() self.handler._parse_id_token.assert_not_called()
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) self.handler._map_userinfo_to_user.assert_called_once_with(
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called() self.handler._render_error.assert_not_called()

View file

@ -17,18 +17,21 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
from .. import unittest from .. import unittest
class RegistrationHandlers(object): class RegistrationHandlers:
def __init__(self, hs): def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs) self.registration_handler = RegistrationHandler(hs)
@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError self.handler.register_user(localpart=invalid_user_id), SynapseError
) )
def test_spam_checker_deny(self):
"""A spam checker can deny registration, which results in an error."""
class DenyAll:
def check_registration_for_spam(
self, email_threepid, username, request_info
):
return RegistrationBehaviour.DENY
# Configure a spam checker that denies all users.
spam_checker = self.hs.get_spam_checker()
spam_checker.spam_checkers = [DenyAll()]
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
def test_spam_checker_shadow_ban(self):
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
class BanAll:
def check_registration_for_spam(
self, email_threepid, username, request_info
):
return RegistrationBehaviour.SHADOW_BAN
# Configure a spam checker that denies all users.
spam_checker = self.hs.get_spam_checker()
spam_checker.spam_checkers = [BanAll()]
user_id = self.get_success(self.handler.register_user(localpart="user"))
# Get an access token.
token = self.macaroon_generator.generate_access_token(user_id)
self.get_success(
self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
)
)
# Ensure the user was marked as shadow-banned.
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
auth = Auth(self.hs)
requester = self.get_success(auth.get_user_by_req(request))
self.assertTrue(requester.shadow_banned)
async def get_or_create_user( async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None self, requester, localpart, displayname, password_hash=None
): ):

View file

@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self): def test_spam_checker(self):
""" """
A user which fails to the spam checks will not appear in search results. A user which fails the spam checks will not appear in search results.
""" """
u1 = self.register_user("user1", "pass") u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass") u1_token = self.login(u1, "pass")
@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users. # Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker() spam_checker = self.hs.get_spam_checker()
class AllowAll(object): class AllowAll:
def check_username_for_spam(self, user_profile): def check_username_for_spam(self, user_profile):
# Allow all users. # Allow all users.
return False return False
@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1) self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users. # Configure a spam checker that filters all users.
class BlockAll(object): class BlockAll:
def check_username_for_spam(self, user_profile): def check_username_for_spam(self, user_profile):
# All users are spammy. # All users are spammy.
return True return True