forked from MirrorHub/synapse
Add rate-limiting on registration (#4735)
* Rate-limiting for registration * Add unit test for registration rate limiting * Add config parameters for rate limiting on auth endpoints * Doc * Fix doc of rate limiting function Co-Authored-By: babolivier <contact@brendanabolivier.com> * Incorporate review * Fix config parsing * Fix linting errors * Set default config for auth rate limiting * Fix tests * Add changelog * Advance reactor instead of mocked clock * Move parameters to registration specific config and give them more sensible default values * Remove unused config options * Don't mock the rate limiter un MAU tests * Rename _register_with_store into register_with_store * Make CI happy * Remove unused import * Update sample config * Fix ratelimiting test for py2 * Add non-guest test
This commit is contained in:
parent
3887e0cd80
commit
a4c3a361b7
17 changed files with 186 additions and 54 deletions
1
changelog.d/4735.feature
Normal file
1
changelog.d/4735.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add configurable rate limiting to the /register endpoint.
|
|
@ -657,6 +657,17 @@ trusted_third_party_id_servers:
|
|||
#
|
||||
autocreate_auto_join_rooms: true
|
||||
|
||||
# Number of registration requests a client can send per second.
|
||||
# Defaults to 1/minute (0.17).
|
||||
#
|
||||
#rc_registration_requests_per_second: 0.17
|
||||
|
||||
# Number of registration requests a client can send before being
|
||||
# throttled.
|
||||
# Defaults to 3.
|
||||
#
|
||||
#rc_registration_request_burst_count: 3.0
|
||||
|
||||
|
||||
## Metrics ###
|
||||
|
||||
|
|
|
@ -23,12 +23,13 @@ class Ratelimiter(object):
|
|||
def __init__(self):
|
||||
self.message_counts = collections.OrderedDict()
|
||||
|
||||
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
|
||||
"""Can the user send a message?
|
||||
def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
|
||||
"""Can the entity (e.g. user or IP address) perform the action?
|
||||
Args:
|
||||
user_id: The user sending a message.
|
||||
key: The key we should use when rate limiting. Can be a user ID
|
||||
(when sending events), an IP address, etc.
|
||||
time_now_s: The time now.
|
||||
msg_rate_hz: The long term number of messages a user can send in a
|
||||
rate_hz: The long term number of messages a user can send in a
|
||||
second.
|
||||
burst_count: How many messages the user can send before being
|
||||
limited.
|
||||
|
@ -41,10 +42,10 @@ class Ratelimiter(object):
|
|||
"""
|
||||
self.prune_message_counts(time_now_s)
|
||||
message_count, time_start, _ignored = self.message_counts.get(
|
||||
user_id, (0., time_now_s, None),
|
||||
key, (0., time_now_s, None),
|
||||
)
|
||||
time_delta = time_now_s - time_start
|
||||
sent_count = message_count - time_delta * msg_rate_hz
|
||||
sent_count = message_count - time_delta * rate_hz
|
||||
if sent_count < 0:
|
||||
allowed = True
|
||||
time_start = time_now_s
|
||||
|
@ -56,13 +57,13 @@ class Ratelimiter(object):
|
|||
message_count += 1
|
||||
|
||||
if update:
|
||||
self.message_counts[user_id] = (
|
||||
message_count, time_start, msg_rate_hz
|
||||
self.message_counts[key] = (
|
||||
message_count, time_start, rate_hz
|
||||
)
|
||||
|
||||
if msg_rate_hz > 0:
|
||||
if rate_hz > 0:
|
||||
time_allowed = (
|
||||
time_start + (message_count - burst_count + 1) / msg_rate_hz
|
||||
time_start + (message_count - burst_count + 1) / rate_hz
|
||||
)
|
||||
if time_allowed < time_now_s:
|
||||
time_allowed = time_now_s
|
||||
|
@ -72,12 +73,12 @@ class Ratelimiter(object):
|
|||
return allowed, time_allowed
|
||||
|
||||
def prune_message_counts(self, time_now_s):
|
||||
for user_id in list(self.message_counts.keys()):
|
||||
message_count, time_start, msg_rate_hz = (
|
||||
self.message_counts[user_id]
|
||||
for key in list(self.message_counts.keys()):
|
||||
message_count, time_start, rate_hz = (
|
||||
self.message_counts[key]
|
||||
)
|
||||
time_delta = time_now_s - time_start
|
||||
if message_count - time_delta * msg_rate_hz > 0:
|
||||
if message_count - time_delta * rate_hz > 0:
|
||||
break
|
||||
else:
|
||||
del self.message_counts[user_id]
|
||||
del self.message_counts[key]
|
||||
|
|
|
@ -54,6 +54,13 @@ class RegistrationConfig(Config):
|
|||
config.get("disable_msisdn_registration", False)
|
||||
)
|
||||
|
||||
self.rc_registration_requests_per_second = config.get(
|
||||
"rc_registration_requests_per_second", 0.17,
|
||||
)
|
||||
self.rc_registration_request_burst_count = config.get(
|
||||
"rc_registration_request_burst_count", 3,
|
||||
)
|
||||
|
||||
def default_config(self, generate_secrets=False, **kwargs):
|
||||
if generate_secrets:
|
||||
registration_shared_secret = 'registration_shared_secret: "%s"' % (
|
||||
|
@ -140,6 +147,17 @@ class RegistrationConfig(Config):
|
|||
# users cannot be auto-joined since they do not exist.
|
||||
#
|
||||
autocreate_auto_join_rooms: true
|
||||
|
||||
# Number of registration requests a client can send per second.
|
||||
# Defaults to 1/minute (0.17).
|
||||
#
|
||||
#rc_registration_requests_per_second: 0.17
|
||||
|
||||
# Number of registration requests a client can send before being
|
||||
# throttled.
|
||||
# Defaults to 3.
|
||||
#
|
||||
#rc_registration_request_burst_count: 3.0
|
||||
""" % locals()
|
||||
|
||||
def add_arguments(self, parser):
|
||||
|
|
|
@ -93,9 +93,9 @@ class BaseHandler(object):
|
|||
messages_per_second = self.hs.config.rc_messages_per_second
|
||||
burst_count = self.hs.config.rc_message_burst_count
|
||||
|
||||
allowed, time_allowed = self.ratelimiter.send_message(
|
||||
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||
user_id, time_now,
|
||||
msg_rate_hz=messages_per_second,
|
||||
rate_hz=messages_per_second,
|
||||
burst_count=burst_count,
|
||||
update=update,
|
||||
)
|
||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.errors import (
|
|||
AuthError,
|
||||
Codes,
|
||||
InvalidCaptchaError,
|
||||
LimitExceededError,
|
||||
RegistrationError,
|
||||
SynapseError,
|
||||
)
|
||||
|
@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler):
|
|||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
self.identity_handler = self.hs.get_handlers().identity_handler
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
|
||||
self._next_generated_user_id = None
|
||||
|
||||
|
@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler):
|
|||
threepid=None,
|
||||
user_type=None,
|
||||
default_display_name=None,
|
||||
address=None,
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
|
@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler):
|
|||
api.constants.UserTypes, or None for a normal user.
|
||||
default_display_name (unicode|None): if set, the new user's displayname
|
||||
will be set to this. Defaults to 'localpart'.
|
||||
address (str|None): the IP address used to perform the regitration.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
|
@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler):
|
|||
token = None
|
||||
if generate_token:
|
||||
token = self.macaroon_gen.generate_access_token(user_id)
|
||||
yield self._register_with_store(
|
||||
yield self.register_with_store(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
|
@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler):
|
|||
create_profile_with_displayname=default_display_name,
|
||||
admin=admin,
|
||||
user_type=user_type,
|
||||
address=address,
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
|
@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler):
|
|||
if default_display_name is None:
|
||||
default_display_name = localpart
|
||||
try:
|
||||
yield self._register_with_store(
|
||||
yield self.register_with_store(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
make_guest=make_guest,
|
||||
create_profile_with_displayname=default_display_name,
|
||||
address=address,
|
||||
)
|
||||
except SynapseError:
|
||||
# if user id is taken, just generate another
|
||||
|
@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler):
|
|||
user_id, allowed_appservice=service
|
||||
)
|
||||
|
||||
yield self._register_with_store(
|
||||
yield self.register_with_store(
|
||||
user_id=user_id,
|
||||
password_hash="",
|
||||
appservice_id=service_id,
|
||||
|
@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler):
|
|||
token = self.macaroon_gen.generate_access_token(user_id)
|
||||
|
||||
if need_register:
|
||||
yield self._register_with_store(
|
||||
yield self.register_with_store(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
|
@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler):
|
|||
ratelimit=False,
|
||||
)
|
||||
|
||||
def _register_with_store(self, user_id, token=None, password_hash=None,
|
||||
was_guest=False, make_guest=False, appservice_id=None,
|
||||
create_profile_with_displayname=None, admin=False,
|
||||
user_type=None):
|
||||
def register_with_store(self, user_id, token=None, password_hash=None,
|
||||
was_guest=False, make_guest=False, appservice_id=None,
|
||||
create_profile_with_displayname=None, admin=False,
|
||||
user_type=None, address=None):
|
||||
"""Register user in the datastore.
|
||||
|
||||
Args:
|
||||
|
@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler):
|
|||
admin (boolean): is an admin user?
|
||||
user_type (str|None): type of user. One of the values from
|
||||
api.constants.UserTypes, or None for a normal user.
|
||||
address (str|None): the IP address used to perform the regitration.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# Don't rate limit for app services
|
||||
if appservice_id is None and address is not None:
|
||||
time_now = self.clock.time()
|
||||
|
||||
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||
address, time_now_s=time_now,
|
||||
rate_hz=self.hs.config.rc_registration_requests_per_second,
|
||||
burst_count=self.hs.config.rc_registration_request_burst_count,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||
)
|
||||
|
||||
if self.hs.config.worker_app:
|
||||
return self._register_client(
|
||||
user_id=user_id,
|
||||
|
@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler):
|
|||
create_profile_with_displayname=create_profile_with_displayname,
|
||||
admin=admin,
|
||||
user_type=user_type,
|
||||
address=address,
|
||||
)
|
||||
else:
|
||||
return self.store.register(
|
||||
|
|
|
@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
def __init__(self, hs):
|
||||
super(ReplicationRegisterServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(
|
||||
user_id, token, password_hash, was_guest, make_guest, appservice_id,
|
||||
create_profile_with_displayname, admin, user_type,
|
||||
create_profile_with_displayname, admin, user_type, address,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
admin (boolean): is an admin user?
|
||||
user_type (str|None): type of user. One of the values from
|
||||
api.constants.UserTypes, or None for a normal user.
|
||||
address (str|None): the IP address used to perform the regitration.
|
||||
"""
|
||||
return {
|
||||
"token": token,
|
||||
|
@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
"create_profile_with_displayname": create_profile_with_displayname,
|
||||
"admin": admin,
|
||||
"user_type": user_type,
|
||||
"address": address,
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_request(self, request, user_id):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
yield self.store.register(
|
||||
yield self.registration_handler.register_with_store(
|
||||
user_id=user_id,
|
||||
token=content["token"],
|
||||
password_hash=content["password_hash"],
|
||||
|
@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
create_profile_with_displayname=content["create_profile_with_displayname"],
|
||||
admin=content["admin"],
|
||||
user_type=content["user_type"],
|
||||
address=content["address"]
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -25,7 +25,12 @@ from twisted.internet import defer
|
|||
import synapse
|
||||
import synapse.types
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
|
||||
from synapse.api.errors import (
|
||||
Codes,
|
||||
LimitExceededError,
|
||||
SynapseError,
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.config.server import is_threepid_reserved
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
|
@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet):
|
|||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
client_addr = request.getClientIP()
|
||||
|
||||
time_now = self.clock.time()
|
||||
|
||||
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||
client_addr, time_now_s=time_now,
|
||||
rate_hz=self.hs.config.rc_registration_requests_per_second,
|
||||
burst_count=self.hs.config.rc_registration_request_burst_count,
|
||||
update=False,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||
)
|
||||
|
||||
kind = b"user"
|
||||
if b"kind" in request.args:
|
||||
kind = request.args[b"kind"][0]
|
||||
|
||||
if kind == b"guest":
|
||||
ret = yield self._do_guest_registration(body)
|
||||
ret = yield self._do_guest_registration(body, address=client_addr)
|
||||
defer.returnValue(ret)
|
||||
return
|
||||
elif kind != b"user":
|
||||
|
@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet):
|
|||
guest_access_token=guest_access_token,
|
||||
generate_token=False,
|
||||
threepid=threepid,
|
||||
address=client_addr,
|
||||
)
|
||||
# Necessary due to auth checks prior to the threepid being
|
||||
# written to the db
|
||||
|
@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet):
|
|||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self, params):
|
||||
def _do_guest_registration(self, params, address=None):
|
||||
if not self.hs.config.allow_guest_access:
|
||||
raise SynapseError(403, "Guest access is disabled")
|
||||
user_id, _ = yield self.registration_handler.register(
|
||||
generate_token=False,
|
||||
make_guest=True
|
||||
make_guest=True,
|
||||
address=address,
|
||||
)
|
||||
|
||||
# we don't allow guests to specify their own device_id, because
|
||||
|
|
|
@ -6,34 +6,34 @@ from tests import unittest
|
|||
class TestRatelimiter(unittest.TestCase):
|
||||
def test_allowed(self):
|
||||
limiter = Ratelimiter()
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
||||
allowed, time_allowed = limiter.can_do_action(
|
||||
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(10., time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
|
||||
allowed, time_allowed = limiter.can_do_action(
|
||||
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
|
||||
)
|
||||
self.assertFalse(allowed)
|
||||
self.assertEquals(10., time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
|
||||
allowed, time_allowed = limiter.can_do_action(
|
||||
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(20., time_allowed)
|
||||
|
||||
def test_pruning(self):
|
||||
limiter = Ratelimiter()
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
|
||||
allowed, time_allowed = limiter.can_do_action(
|
||||
key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
|
||||
)
|
||||
|
||||
self.assertIn("test_id_1", limiter.message_counts)
|
||||
|
||||
allowed, time_allowed = limiter.send_message(
|
||||
user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
|
||||
allowed, time_allowed = limiter.can_do_action(
|
||||
key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
|
||||
)
|
||||
|
||||
self.assertNotIn("test_id_1", limiter.message_counts)
|
||||
|
|
|
@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase):
|
|||
federation_client=self.mock_federation,
|
||||
federation_server=Mock(),
|
||||
federation_registry=self.mock_registry,
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||
)
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
|
|
@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
|||
hs = self.setup_test_homeserver(
|
||||
"blue",
|
||||
federation_client=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||
)
|
||||
|
||||
hs.get_ratelimiter().send_message.return_value = (True, 0)
|
||||
hs.get_ratelimiter().can_do_action.return_value = (True, 0)
|
||||
|
||||
return hs
|
||||
|
||||
|
|
|
@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
|
|||
config.auto_join_rooms = []
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
|
||||
config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
|
|
|
@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase):
|
|||
"red",
|
||||
http_client=None,
|
||||
federation_client=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||
)
|
||||
self.ratelimiter = self.hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||
|
||||
self.hs.get_federation_handler = Mock(return_value=Mock())
|
||||
|
||||
|
@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase):
|
|||
# auth as user_id now
|
||||
self.helper.auth_user_id = self.user_id
|
||||
|
||||
def test_send_message(self):
|
||||
def test_can_do_action(self):
|
||||
msg_content = b'{"msgtype":"m.text","body":"hello"}'
|
||||
|
||||
seq = iter(range(100))
|
||||
|
|
|
@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||
"red",
|
||||
http_client=None,
|
||||
federation_client=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
|
||||
)
|
||||
|
||||
self.event_source = hs.get_event_sources().sources["typing"]
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
self.ratelimiter.can_do_action.return_value = (True, 0)
|
||||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
|
|
|
@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
||||
|
||||
def test_POST_ratelimiting_guest(self):
|
||||
self.hs.config.rc_registration_request_burst_count = 5
|
||||
|
||||
for i in range(0, 6):
|
||||
url = self.url + b"?kind=guest"
|
||||
request, channel = self.make_request(b"POST", url, b"{}")
|
||||
self.render(request)
|
||||
|
||||
if i == 5:
|
||||
self.assertEquals(channel.result["code"], b"429", channel.result)
|
||||
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
||||
else:
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.)
|
||||
|
||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
def test_POST_ratelimiting(self):
|
||||
self.hs.config.rc_registration_request_burst_count = 5
|
||||
|
||||
for i in range(0, 6):
|
||||
params = {
|
||||
"username": "kermit" + str(i),
|
||||
"password": "monkey",
|
||||
"device_id": "frogfone",
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
if i == 5:
|
||||
self.assertEquals(channel.result["code"], b"429", channel.result)
|
||||
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
||||
else:
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.)
|
||||
|
||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
import json
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
from mock import Mock
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
|
@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
|||
"red",
|
||||
http_client=None,
|
||||
federation_client=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
)
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
|
|
|
@ -150,6 +150,8 @@ def default_config(name):
|
|||
config.admin_contact = None
|
||||
config.rc_messages_per_second = 10000
|
||||
config.rc_message_burst_count = 10000
|
||||
config.rc_registration_request_burst_count = 3.0
|
||||
config.rc_registration_requests_per_second = 0.17
|
||||
config.saml2_enabled = False
|
||||
config.public_baseurl = None
|
||||
config.default_identity_server = None
|
||||
|
|
Loading…
Reference in a new issue