0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-02 02:38:59 +02:00

Add rate-limiting on registration (#4735)

* Rate-limiting for registration

* Add unit test for registration rate limiting

* Add config parameters for rate limiting on auth endpoints

* Doc

* Fix doc of rate limiting function

Co-Authored-By: babolivier <contact@brendanabolivier.com>

* Incorporate review

* Fix config parsing

* Fix linting errors

* Set default config for auth rate limiting

* Fix tests

* Add changelog

* Advance reactor instead of mocked clock

* Move parameters to registration specific config and give them more sensible default values

* Remove unused config options

* Don't mock the rate limiter un MAU tests

* Rename _register_with_store into register_with_store

* Make CI happy

* Remove unused import

* Update sample config

* Fix ratelimiting test for py2

* Add non-guest test
This commit is contained in:
Brendan Abolivier 2019-03-05 14:25:33 +00:00 committed by GitHub
parent 3887e0cd80
commit a4c3a361b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 186 additions and 54 deletions

1
changelog.d/4735.feature Normal file
View file

@ -0,0 +1 @@
Add configurable rate limiting to the /register endpoint.

View file

@ -657,6 +657,17 @@ trusted_third_party_id_servers:
#
autocreate_auto_join_rooms: true
# Number of registration requests a client can send per second.
# Defaults to 1/minute (0.17).
#
#rc_registration_requests_per_second: 0.17
# Number of registration requests a client can send before being
# throttled.
# Defaults to 3.
#
#rc_registration_request_burst_count: 3.0
## Metrics ###

View file

@ -23,12 +23,13 @@ class Ratelimiter(object):
def __init__(self):
self.message_counts = collections.OrderedDict()
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
"""Can the user send a message?
def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
"""Can the entity (e.g. user or IP address) perform the action?
Args:
user_id: The user sending a message.
key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc.
time_now_s: The time now.
msg_rate_hz: The long term number of messages a user can send in a
rate_hz: The long term number of messages a user can send in a
second.
burst_count: How many messages the user can send before being
limited.
@ -41,10 +42,10 @@ class Ratelimiter(object):
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.get(
user_id, (0., time_now_s, None),
key, (0., time_now_s, None),
)
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * msg_rate_hz
sent_count = message_count - time_delta * rate_hz
if sent_count < 0:
allowed = True
time_start = time_now_s
@ -56,13 +57,13 @@ class Ratelimiter(object):
message_count += 1
if update:
self.message_counts[user_id] = (
message_count, time_start, msg_rate_hz
self.message_counts[key] = (
message_count, time_start, rate_hz
)
if msg_rate_hz > 0:
if rate_hz > 0:
time_allowed = (
time_start + (message_count - burst_count + 1) / msg_rate_hz
time_start + (message_count - burst_count + 1) / rate_hz
)
if time_allowed < time_now_s:
time_allowed = time_now_s
@ -72,12 +73,12 @@ class Ratelimiter(object):
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
for user_id in list(self.message_counts.keys()):
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
for key in list(self.message_counts.keys()):
message_count, time_start, rate_hz = (
self.message_counts[key]
)
time_delta = time_now_s - time_start
if message_count - time_delta * msg_rate_hz > 0:
if message_count - time_delta * rate_hz > 0:
break
else:
del self.message_counts[user_id]
del self.message_counts[key]

View file

@ -54,6 +54,13 @@ class RegistrationConfig(Config):
config.get("disable_msisdn_registration", False)
)
self.rc_registration_requests_per_second = config.get(
"rc_registration_requests_per_second", 0.17,
)
self.rc_registration_request_burst_count = config.get(
"rc_registration_request_burst_count", 3,
)
def default_config(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@ -140,6 +147,17 @@ class RegistrationConfig(Config):
# users cannot be auto-joined since they do not exist.
#
autocreate_auto_join_rooms: true
# Number of registration requests a client can send per second.
# Defaults to 1/minute (0.17).
#
#rc_registration_requests_per_second: 0.17
# Number of registration requests a client can send before being
# throttled.
# Defaults to 3.
#
#rc_registration_request_burst_count: 3.0
""" % locals()
def add_arguments(self, parser):

View file

@ -93,9 +93,9 @@ class BaseHandler(object):
messages_per_second = self.hs.config.rc_messages_per_second
burst_count = self.hs.config.rc_message_burst_count
allowed, time_allowed = self.ratelimiter.send_message(
allowed, time_allowed = self.ratelimiter.can_do_action(
user_id, time_now,
msg_rate_hz=messages_per_second,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)

View file

@ -24,6 +24,7 @@ from synapse.api.errors import (
AuthError,
Codes,
InvalidCaptchaError,
LimitExceededError,
RegistrationError,
SynapseError,
)
@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_ratelimiter()
self._next_generated_user_id = None
@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler):
threepid=None,
user_type=None,
default_display_name=None,
address=None,
):
"""Registers a new client on the server.
@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler):
api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the regitration.
Returns:
A tuple of (user_id, access_token).
Raises:
@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler):
token = None
if generate_token:
token = self.macaroon_gen.generate_access_token(user_id)
yield self._register_with_store(
yield self.register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=default_display_name,
admin=admin,
user_type=user_type,
address=address,
)
if self.hs.config.user_directory_search_all_users:
@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler):
if default_display_name is None:
default_display_name = localpart
try:
yield self._register_with_store(
yield self.register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
address=address,
)
except SynapseError:
# if user id is taken, just generate another
@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
yield self._register_with_store(
yield self.register_with_store(
user_id=user_id,
password_hash="",
appservice_id=service_id,
@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler):
token = self.macaroon_gen.generate_access_token(user_id)
if need_register:
yield self._register_with_store(
yield self.register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
def _register_with_store(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False,
user_type=None):
def register_with_store(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False,
user_type=None, address=None):
"""Register user in the datastore.
Args:
@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
Returns:
Deferred
"""
# Don't rate limit for app services
if appservice_id is None and address is not None:
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
address, time_now_s=time_now,
rate_hz=self.hs.config.rc_registration_requests_per_second,
burst_count=self.hs.config.rc_registration_request_burst_count,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
address=address,
)
else:
return self.store.register(

View file

@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
def __init__(self, hs):
super(ReplicationRegisterServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler()
@staticmethod
def _serialize_payload(
user_id, token, password_hash, was_guest, make_guest, appservice_id,
create_profile_with_displayname, admin, user_type,
create_profile_with_displayname, admin, user_type, address,
):
"""
Args:
@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the regitration.
"""
return {
"token": token,
@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"create_profile_with_displayname": create_profile_with_displayname,
"admin": admin,
"user_type": user_type,
"address": address,
}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
yield self.store.register(
yield self.registration_handler.register_with_store(
user_id=user_id,
token=content["token"],
password_hash=content["password_hash"],
@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
create_profile_with_displayname=content["create_profile_with_displayname"],
admin=content["admin"],
user_type=content["user_type"],
address=content["address"]
)
defer.returnValue((200, {}))

View file

@ -25,7 +25,12 @@ from twisted.internet import defer
import synapse
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.api.errors import (
Codes,
LimitExceededError,
SynapseError,
UnrecognizedRequestError,
)
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import (
RestServlet,
@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_ratelimiter()
self.clock = hs.get_clock()
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
client_addr, time_now_s=time_now,
rate_hz=self.hs.config.rc_registration_requests_per_second,
burst_count=self.hs.config.rc_registration_request_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
kind = b"user"
if b"kind" in request.args:
kind = request.args[b"kind"][0]
if kind == b"guest":
ret = yield self._do_guest_registration(body)
ret = yield self._do_guest_registration(body, address=client_addr)
defer.returnValue(ret)
return
elif kind != b"user":
@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet):
guest_access_token=guest_access_token,
generate_token=False,
threepid=threepid,
address=client_addr,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet):
defer.returnValue(result)
@defer.inlineCallbacks
def _do_guest_registration(self, params):
def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register(
generate_token=False,
make_guest=True
make_guest=True,
address=address,
)
# we don't allow guests to specify their own device_id, because

View file

@ -6,34 +6,34 @@ from tests import unittest
class TestRatelimiter(unittest.TestCase):
def test_allowed(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
)
self.assertFalse(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed)
self.assertEquals(20., time_allowed)
def test_pruning(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
allowed, time_allowed = limiter.can_do_action(
key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
)
self.assertIn("test_id_1", limiter.message_counts)
allowed, time_allowed = limiter.send_message(
user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
allowed, time_allowed = limiter.can_do_action(
key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
)
self.assertNotIn("test_id_1", limiter.message_counts)

View file

@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=["send_message"]),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
self.ratelimiter.can_do_action.return_value = (True, 0)
self.store = hs.get_datastore()

View file

@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(
"blue",
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
hs.get_ratelimiter().send_message.return_value = (True, 0)
hs.get_ratelimiter().can_do_action.return_value = (True, 0)
return hs

View file

@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
config.auto_join_rooms = []
hs = self.setup_test_homeserver(
config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock()

View file

@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase):
"red",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
self.ratelimiter = self.hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
self.ratelimiter.can_do_action.return_value = (True, 0)
self.hs.get_federation_handler = Mock(return_value=Mock())
@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase):
# auth as user_id now
self.helper.auth_user_id = self.user_id
def test_send_message(self):
def test_can_do_action(self):
msg_content = b'{"msgtype":"m.text","body":"hello"}'
seq = iter(range(100))

View file

@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"red",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
self.event_source = hs.get_event_sources().sources["typing"]
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock()

View file

@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
def test_POST_ratelimiting_guest(self):
self.hs.config.rc_registration_request_burst_count = 5
for i in range(0, 6):
url = self.url + b"?kind=guest"
request, channel = self.make_request(b"POST", url, b"{}")
self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_POST_ratelimiting(self):
self.hs.config.rc_registration_request_burst_count = 5
for i in range(0, 6):
params = {
"username": "kermit" + str(i),
"password": "monkey",
"device_id": "frogfone",
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)

View file

@ -17,7 +17,7 @@
import json
from mock import Mock, NonCallableMock
from mock import Mock
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
"red",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.store = self.hs.get_datastore()

View file

@ -150,6 +150,8 @@ def default_config(name):
config.admin_contact = None
config.rc_messages_per_second = 10000
config.rc_message_burst_count = 10000
config.rc_registration_request_burst_count = 3.0
config.rc_registration_requests_per_second = 0.17
config.saml2_enabled = False
config.public_baseurl = None
config.default_identity_server = None