Ratelimit 3PID /requestToken API (#9238)

This commit is contained in:
Erik Johnston 2021-01-28 17:39:21 +00:00 committed by GitHub
parent 54a6afeee3
commit 4b73488e81
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 159 additions and 14 deletions

1
changelog.d/9238.feature Normal file
View file

@ -0,0 +1 @@
Add ratelimited to 3PID `/requestToken` API.

View file

@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# users are joining rooms the server is already in (this is cheap) vs # users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which # "remote" for when users are trying to join rooms not on the server (which
# can be more expensive) # can be more expensive)
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
# #
# The defaults are as shown below. # The defaults are as shown below.
# #
@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# remote: # remote:
# per_second: 0.01 # per_second: 0.01
# burst_count: 3 # burst_count: 3
#
#rc_3pid_validation:
# per_second: 0.003
# burst_count: 5
# Ratelimiting settings for incoming federation # Ratelimiting settings for incoming federation
# #

View file

@ -54,7 +54,7 @@ class RootConfig:
tls: tls.TlsConfig tls: tls.TlsConfig
database: database.DatabaseConfig database: database.DatabaseConfig
logging: logger.LoggingConfig logging: logger.LoggingConfig
ratelimit: ratelimiting.RatelimitConfig ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig captcha: captcha.CaptchaConfig
voip: voip.VoipConfig voip: voip.VoipConfig

View file

@ -24,7 +24,7 @@ class RateLimitConfig:
defaults={"per_second": 0.17, "burst_count": 3.0}, defaults={"per_second": 0.17, "burst_count": 3.0},
): ):
self.per_second = config.get("per_second", defaults["per_second"]) self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"]) self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
class FederationRateLimitConfig: class FederationRateLimitConfig:
@ -102,6 +102,11 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 3}, defaults={"per_second": 0.01, "burst_count": 3},
) )
self.rc_3pid_validation = RateLimitConfig(
config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5},
)
def generate_config_section(self, **kwargs): def generate_config_section(self, **kwargs):
return """\ return """\
## Ratelimiting ## ## Ratelimiting ##
@ -131,6 +136,7 @@ class RatelimitConfig(Config):
# users are joining rooms the server is already in (this is cheap) vs # users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which # "remote" for when users are trying to join rooms not on the server (which
# can be more expensive) # can be more expensive)
# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
# #
# The defaults are as shown below. # The defaults are as shown below.
# #
@ -164,7 +170,10 @@ class RatelimitConfig(Config):
# remote: # remote:
# per_second: 0.01 # per_second: 0.01
# burst_count: 3 # burst_count: 3
#
#rc_3pid_validation:
# per_second: 0.003
# burst_count: 5
# Ratelimiting settings for incoming federation # Ratelimiting settings for incoming federation
# #

View file

@ -27,9 +27,11 @@ from synapse.api.errors import (
HttpResponseException, HttpResponseException,
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester from synapse.types import JsonDict, Requester
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64 from synapse.util.hash import sha256_and_url_safe_base64
@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler):
self._web_client_location = hs.config.invite_client_location self._web_client_location = hs.config.invite_client_location
# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
def ratelimit_request_token_requests(
self, request: SynapseRequest, medium: str, address: str,
):
"""Used to ratelimit requests to `/requestToken` by IP and address.
Args:
request: The associated request
medium: The type of threepid, e.g. "msisdn" or "email"
address: The actual threepid ID, e.g. the phone number or email address
"""
self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
async def threepid_from_creds( async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str] self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]: ) -> Optional[JsonDict]:

View file

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet): class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$") PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid # Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link) assert_valid_next_link(self.hs, next_link)
self.identity_handler.ratelimit_request_token_requests(request, "email", email)
# The email will be sent to the stored address. # The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to # This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after # an email address which is controlled by the attacker but which, after
@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
self.identity_handler.ratelimit_request_token_requests(request, "email", email)
if next_link: if next_link:
# Raise if the provided next_link value isn't valid # Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link) assert_valid_next_link(self.hs, next_link)
@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
super().__init__() super().__init__()
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)
if next_link: if next_link:
# Raise if the provided next_link value isn't valid # Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link) assert_valid_next_link(self.hs, next_link)

View file

@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
self.identity_handler.ratelimit_request_token_requests(request, "email", email)
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email "email", email
) )
@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn "msisdn", msisdn
) )

View file

@ -24,7 +24,7 @@ import pkg_resources
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import LoginType, Membership from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password # Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password) self.attempt_wrong_password_login("kermit", old_password)
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self):
"""Test that we ratelimit /requestToken for the same email.
"""
old_password = "monkey"
new_password = "kangeroo"
user_id = self.register_user("kermit", old_password)
self.login("kermit", old_password)
email = "test1@example.com"
# Add a threepid
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
medium="email",
address=email,
validated_at=0,
added_at=0,
)
)
def reset(ip):
client_secret = "foobar"
session_id = self._request_token(email, client_secret, ip)
self.assertEquals(len(self.email_attempts), 1)
link = self._get_link_from_email()
self._validate_token(link)
self._reset_password(new_password, session_id, client_secret)
self.email_attempts.clear()
# We expect to be able to make three requests before getting rate
# limited.
#
# We change IPs to ensure that we're not being ratelimited due to the
# same IP
reset("127.0.0.1")
reset("127.0.0.2")
reset("127.0.0.3")
with self.assertRaises(HttpResponseException) as cm:
reset("127.0.0.4")
self.assertEqual(cm.exception.code, 429)
def test_basic_password_reset_canonicalise_email(self): def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow """Test basic password reset flow
Request password reset with different spelling Request password reset with different spelling
@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id) self.assertIsNotNone(session_id)
def _request_token(self, email, client_secret): def _request_token(self, email, client_secret, ip="127.0.0.1"):
channel = self.make_request( channel = self.make_request(
"POST", "POST",
b"account/password/email/requestToken", b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1}, {"client_secret": client_secret, "email": email, "send_attempt": 1},
client_ip=ip,
) )
self.assertEquals(200, channel.code, channel.result)
if channel.code != 200:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"],
)
return channel.json_body["sid"] return channel.json_body["sid"]
@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_address_trim(self): def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_ip(self):
"""Tests that adding emails is ratelimited by IP
"""
# We expect to be able to set three emails before getting ratelimited.
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
with self.assertRaises(HttpResponseException) as cm:
self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
self.assertEqual(cm.exception.code, 429)
def test_add_email_if_disabled(self): def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed """Test adding email to profile when doing so is disallowed
""" """
@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
body["next_link"] = next_link body["next_link"] = next_link
channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
self.assertEquals(expect_code, channel.code, channel.result)
if channel.code != expect_code:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"],
)
return channel.json_body.get("sid") return channel.json_body.get("sid")
@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _add_email(self, request_email, expected_email): def _add_email(self, request_email, expected_email):
"""Test adding an email to profile """Test adding an email to profile
""" """
previous_email_attempts = len(self.email_attempts)
client_secret = "foobar" client_secret = "foobar"
session_id = self._request_token(request_email, client_secret) session_id = self._request_token(request_email, client_secret)
self.assertEquals(len(self.email_attempts), 1) self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email() link = self._get_link_from_email()
self._validate_token(link) self._validate_token(link)
@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
self.assertIn(expected_email, threepids)

View file

@ -47,6 +47,7 @@ class FakeChannel:
site = attr.ib(type=Site) site = attr.ib(type=Site)
_reactor = attr.ib() _reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict)) result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
_producer = None _producer = None
@property @property
@ -120,7 +121,7 @@ class FakeChannel:
def getPeer(self): def getPeer(self):
# We give an address so that getClientIP returns a non null entry, # We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU # causing us to record the MAU
return address.IPv4Address("TCP", "127.0.0.1", 3423) return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self): def getHost(self):
return None return None
@ -196,6 +197,7 @@ def make_request(
custom_headers: Optional[ custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None, ] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel: ) -> FakeChannel:
""" """
Make a web request using the given method, path and content, and render it Make a web request using the given method, path and content, and render it
@ -223,6 +225,9 @@ def make_request(
will pump the reactor until the the renderer tells the channel the request will pump the reactor until the the renderer tells the channel the request
is finished. is finished.
client_ip: The IP to use as the requesting IP. Useful for testing
ratelimiting.
Returns: Returns:
channel channel
""" """
@ -250,7 +255,7 @@ def make_request(
if isinstance(content, str): if isinstance(content, str):
content = content.encode("utf8") content = content.encode("utf8")
channel = FakeChannel(site, reactor) channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel) req = request(channel)
req.content = BytesIO(content) req.content = BytesIO(content)

View file

@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase):
custom_headers: Optional[ custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None, ] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel: ) -> FakeChannel:
""" """
Create a SynapseRequest at the path using the method and containing the Create a SynapseRequest at the path using the method and containing the
@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase):
custom_headers: (name, value) pairs to add as request headers custom_headers: (name, value) pairs to add as request headers
client_ip: The IP to use as the requesting IP. Useful for testing
ratelimiting.
Returns: Returns:
The FakeChannel object which stores the result of the request. The FakeChannel object which stores the result of the request.
""" """
@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase):
content_is_form, content_is_form,
await_result, await_result,
custom_headers, custom_headers,
client_ip,
) )
def setup_test_homeserver(self, *args, **kwargs): def setup_test_homeserver(self, *args, **kwargs):

View file

@ -157,6 +157,7 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000}, "local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000},
}, },
"rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
"saml2_enabled": False, "saml2_enabled": False,
"default_identity_server": None, "default_identity_server": None,
"key_refresh_interval": 24 * 60 * 60 * 1000, "key_refresh_interval": 24 * 60 * 60 * 1000,