forked from MirrorHub/synapse
do mau checks based on monthly_active_users table
This commit is contained in:
parent
9180061b49
commit
74b1d46ad9
7 changed files with 97 additions and 61 deletions
|
@ -773,3 +773,16 @@ class Auth(object):
|
|||
raise AuthError(
|
||||
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth_blocking(self, error):
|
||||
"""Checks if the user should be rejected for some external reason,
|
||||
such as monthly active user limiting or global disable flag
|
||||
Args:
|
||||
error (Error): The error that should be raised if user is to be
|
||||
blocked
|
||||
"""
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
current_mau = yield self.store.get_monthly_active_count()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
raise error
|
||||
|
|
|
@ -913,12 +913,10 @@ class AuthHandler(BaseHandler):
|
|||
Ensure that if mau blocking is enabled that invalid users cannot
|
||||
log in.
|
||||
"""
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
current_mau = yield self.store.count_monthly_users()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
raise AuthError(
|
||||
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||
error = AuthError(
|
||||
403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||
)
|
||||
yield self.auth.check_auth_blocking(error)
|
||||
|
||||
|
||||
@attr.s
|
||||
|
|
|
@ -540,9 +540,7 @@ class RegistrationHandler(BaseHandler):
|
|||
Do not accept registrations if monthly active user limits exceeded
|
||||
and limiting is enabled
|
||||
"""
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
current_mau = yield self.store.count_monthly_users()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
raise RegistrationError(
|
||||
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
|
||||
error = RegistrationError(
|
||||
403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||
)
|
||||
yield self.auth.check_auth_blocking(error)
|
||||
|
|
|
@ -97,21 +97,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_monthly_active_users(self, user_id):
|
||||
"""Checks on the state of monthly active user limits and optionally
|
||||
add the user to the monthly active tables
|
||||
|
||||
Args:
|
||||
user_id(str): the user_id to query
|
||||
"""
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
print "entering _populate_monthly_active_users"
|
||||
if self.hs.config.limit_usage_by_mau:
|
||||
print "self.hs.config.limit_usage_by_mau is TRUE"
|
||||
is_user_monthly_active = yield store.is_user_monthly_active(user_id)
|
||||
print "is_user_monthly_active is %r" % is_user_monthly_active
|
||||
if is_user_monthly_active:
|
||||
yield store.upsert_monthly_active_user(user_id)
|
||||
else:
|
||||
count = yield store.get_monthly_active_count()
|
||||
print "count is %d" % count
|
||||
if count < self.hs.config.max_mau_value:
|
||||
print "count is less than self.hs.config.max_mau_value "
|
||||
res = yield store.upsert_monthly_active_user(user_id)
|
||||
print "upsert response is %r" % res
|
||||
yield store.upsert_monthly_active_user(user_id)
|
||||
|
||||
def _update_client_ips_batch(self):
|
||||
def update():
|
||||
|
|
|
@ -21,7 +21,7 @@ from twisted.internet import defer
|
|||
|
||||
import synapse.handlers.auth
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.errors import AuthError, Codes
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
|
@ -444,3 +444,32 @@ class AuthTestCase(unittest.TestCase):
|
|||
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
|
||||
|
||||
self.store.get_user_by_id.assert_called_with(USER_ID)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_blocking_mau(self):
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
self.hs.config.max_mau_value = 50
|
||||
lots_of_users = 100
|
||||
small_number_of_users = 1
|
||||
|
||||
error = AuthError(
|
||||
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||
)
|
||||
|
||||
# Ensure no error thrown
|
||||
yield self.auth.check_auth_blocking(error)
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(lots_of_users)
|
||||
)
|
||||
|
||||
with self.assertRaises(AuthError):
|
||||
yield self.auth.check_auth_blocking(error)
|
||||
|
||||
# Ensure does not throw an error
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(small_number_of_users)
|
||||
)
|
||||
yield self.auth.check_auth_blocking(error)
|
||||
|
|
|
@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_mau_limits_exceeded(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.large_number_of_users)
|
||||
)
|
||||
|
||||
with self.assertRaises(AuthError):
|
||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.large_number_of_users)
|
||||
)
|
||||
with self.assertRaises(AuthError):
|
||||
|
@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase):
|
|||
def test_mau_limits_not_exceeded(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.small_number_of_users)
|
||||
)
|
||||
# Ensure does not raise exception
|
||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||
|
||||
self.hs.get_datastore().count_monthly_users = Mock(
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.small_number_of_users)
|
||||
)
|
||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
|
|
|
@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
self.store = self.hs.get_datastore()
|
||||
self.hs.config.max_mau_value = 50
|
||||
self.lots_of_users = 100
|
||||
self.small_number_of_users = 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
|
@ -80,51 +84,44 @@ class RegistrationTestCase(unittest.TestCase):
|
|||
self.assertEquals(result_token, 'secret')
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cannot_register_when_mau_limits_exceeded(self):
|
||||
local_part = "someone"
|
||||
display_name = "someone"
|
||||
requester = create_requester("@as:test")
|
||||
store = self.hs.get_datastore()
|
||||
def test_mau_limits_when_disabled(self):
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
self.hs.config.max_mau_value = 50
|
||||
lots_of_users = 100
|
||||
small_number_users = 1
|
||||
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||
|
||||
# Ensure does not throw exception
|
||||
yield self.handler.get_or_create_user(requester, 'a', display_name)
|
||||
yield self.handler.get_or_create_user("requester", 'a', "display_name")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_or_create_user_mau_not_blocked(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
|
||||
with self.assertRaises(RegistrationError):
|
||||
yield self.handler.get_or_create_user(requester, 'b', display_name)
|
||||
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
|
||||
|
||||
self._macaroon_mock_generator("another_secret")
|
||||
|
||||
self.store.count_monthly_users = Mock(
|
||||
return_value=defer.succeed(self.small_number_of_users)
|
||||
)
|
||||
# Ensure does not throw exception
|
||||
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
|
||||
yield self.handler.get_or_create_user("@user:server", 'c', "User")
|
||||
|
||||
self._macaroon_mock_generator("another another secret")
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||
@defer.inlineCallbacks
|
||||
def test_get_or_create_user_mau_blocked(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.lots_of_users)
|
||||
)
|
||||
|
||||
with self.assertRaises(RegistrationError):
|
||||
yield self.handler.register(localpart=local_part)
|
||||
|
||||
self._macaroon_mock_generator("another another secret")
|
||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
||||
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_register_mau_blocked(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.lots_of_users)
|
||||
)
|
||||
with self.assertRaises(RegistrationError):
|
||||
yield self.handler.register_saml2(local_part)
|
||||
yield self.handler.register(localpart="local_part")
|
||||
|
||||
def _macaroon_mock_generator(self, secret):
|
||||
"""
|
||||
Reset macaroon generator in the case where the test creates multiple users
|
||||
"""
|
||||
macaroon_generator = Mock(
|
||||
generate_access_token=Mock(return_value=secret))
|
||||
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
@defer.inlineCallbacks
|
||||
def test_register_saml2_mau_blocked(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.lots_of_users)
|
||||
)
|
||||
with self.assertRaises(RegistrationError):
|
||||
yield self.handler.register_saml2(localpart="local_part")
|
||||
|
|
Loading…
Reference in a new issue