forked from MirrorHub/synapse
do mau checks based on monthly_active_users table
This commit is contained in:
parent
9180061b49
commit
74b1d46ad9
7 changed files with 97 additions and 61 deletions
|
@ -773,3 +773,16 @@ class Auth(object):
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_auth_blocking(self, error):
|
||||||
|
"""Checks if the user should be rejected for some external reason,
|
||||||
|
such as monthly active user limiting or global disable flag
|
||||||
|
Args:
|
||||||
|
error (Error): The error that should be raised if user is to be
|
||||||
|
blocked
|
||||||
|
"""
|
||||||
|
if self.hs.config.limit_usage_by_mau is True:
|
||||||
|
current_mau = yield self.store.get_monthly_active_count()
|
||||||
|
if current_mau >= self.hs.config.max_mau_value:
|
||||||
|
raise error
|
||||||
|
|
|
@ -913,12 +913,10 @@ class AuthHandler(BaseHandler):
|
||||||
Ensure that if mau blocking is enabled that invalid users cannot
|
Ensure that if mau blocking is enabled that invalid users cannot
|
||||||
log in.
|
log in.
|
||||||
"""
|
"""
|
||||||
if self.hs.config.limit_usage_by_mau is True:
|
error = AuthError(
|
||||||
current_mau = yield self.store.count_monthly_users()
|
403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||||
if current_mau >= self.hs.config.max_mau_value:
|
|
||||||
raise AuthError(
|
|
||||||
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
|
||||||
)
|
)
|
||||||
|
yield self.auth.check_auth_blocking(error)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
|
|
|
@ -540,9 +540,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
Do not accept registrations if monthly active user limits exceeded
|
Do not accept registrations if monthly active user limits exceeded
|
||||||
and limiting is enabled
|
and limiting is enabled
|
||||||
"""
|
"""
|
||||||
if self.hs.config.limit_usage_by_mau is True:
|
error = RegistrationError(
|
||||||
current_mau = yield self.store.count_monthly_users()
|
403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||||
if current_mau >= self.hs.config.max_mau_value:
|
|
||||||
raise RegistrationError(
|
|
||||||
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
|
|
||||||
)
|
)
|
||||||
|
yield self.auth.check_auth_blocking(error)
|
||||||
|
|
|
@ -97,21 +97,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _populate_monthly_active_users(self, user_id):
|
def _populate_monthly_active_users(self, user_id):
|
||||||
|
"""Checks on the state of monthly active user limits and optionally
|
||||||
|
add the user to the monthly active tables
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id(str): the user_id to query
|
||||||
|
"""
|
||||||
|
|
||||||
store = self.hs.get_datastore()
|
store = self.hs.get_datastore()
|
||||||
print "entering _populate_monthly_active_users"
|
|
||||||
if self.hs.config.limit_usage_by_mau:
|
if self.hs.config.limit_usage_by_mau:
|
||||||
print "self.hs.config.limit_usage_by_mau is TRUE"
|
|
||||||
is_user_monthly_active = yield store.is_user_monthly_active(user_id)
|
is_user_monthly_active = yield store.is_user_monthly_active(user_id)
|
||||||
print "is_user_monthly_active is %r" % is_user_monthly_active
|
|
||||||
if is_user_monthly_active:
|
if is_user_monthly_active:
|
||||||
yield store.upsert_monthly_active_user(user_id)
|
yield store.upsert_monthly_active_user(user_id)
|
||||||
else:
|
else:
|
||||||
count = yield store.get_monthly_active_count()
|
count = yield store.get_monthly_active_count()
|
||||||
print "count is %d" % count
|
|
||||||
if count < self.hs.config.max_mau_value:
|
if count < self.hs.config.max_mau_value:
|
||||||
print "count is less than self.hs.config.max_mau_value "
|
yield store.upsert_monthly_active_user(user_id)
|
||||||
res = yield store.upsert_monthly_active_user(user_id)
|
|
||||||
print "upsert response is %r" % res
|
|
||||||
|
|
||||||
def _update_client_ips_batch(self):
|
def _update_client_ips_batch(self):
|
||||||
def update():
|
def update():
|
||||||
|
|
|
@ -21,7 +21,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.handlers.auth
|
import synapse.handlers.auth
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError, Codes
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -444,3 +444,32 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
|
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
|
||||||
|
|
||||||
self.store.get_user_by_id.assert_called_with(USER_ID)
|
self.store.get_user_by_id.assert_called_with(USER_ID)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_blocking_mau(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = False
|
||||||
|
self.hs.config.max_mau_value = 50
|
||||||
|
lots_of_users = 100
|
||||||
|
small_number_of_users = 1
|
||||||
|
|
||||||
|
error = AuthError(
|
||||||
|
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure no error thrown
|
||||||
|
yield self.auth.check_auth_blocking(error)
|
||||||
|
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
|
||||||
|
self.store.get_monthly_active_count = Mock(
|
||||||
|
return_value=defer.succeed(lots_of_users)
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(AuthError):
|
||||||
|
yield self.auth.check_auth_blocking(error)
|
||||||
|
|
||||||
|
# Ensure does not throw an error
|
||||||
|
self.store.get_monthly_active_count = Mock(
|
||||||
|
return_value=defer.succeed(small_number_of_users)
|
||||||
|
)
|
||||||
|
yield self.auth.check_auth_blocking(error)
|
||||||
|
|
|
@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_mau_limits_exceeded(self):
|
def test_mau_limits_exceeded(self):
|
||||||
self.hs.config.limit_usage_by_mau = True
|
self.hs.config.limit_usage_by_mau = True
|
||||||
self.hs.get_datastore().count_monthly_users = Mock(
|
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.large_number_of_users)
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(AuthError):
|
||||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
self.hs.get_datastore().count_monthly_users = Mock(
|
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.large_number_of_users)
|
return_value=defer.succeed(self.large_number_of_users)
|
||||||
)
|
)
|
||||||
with self.assertRaises(AuthError):
|
with self.assertRaises(AuthError):
|
||||||
|
@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase):
|
||||||
def test_mau_limits_not_exceeded(self):
|
def test_mau_limits_not_exceeded(self):
|
||||||
self.hs.config.limit_usage_by_mau = True
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
|
||||||
self.hs.get_datastore().count_monthly_users = Mock(
|
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.small_number_of_users)
|
return_value=defer.succeed(self.small_number_of_users)
|
||||||
)
|
)
|
||||||
# Ensure does not raise exception
|
# Ensure does not raise exception
|
||||||
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
yield self.auth_handler.get_access_token_for_user_id('user_a')
|
||||||
|
|
||||||
self.hs.get_datastore().count_monthly_users = Mock(
|
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||||
return_value=defer.succeed(self.small_number_of_users)
|
return_value=defer.succeed(self.small_number_of_users)
|
||||||
)
|
)
|
||||||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
|
|
|
@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
self.handler = self.hs.get_handlers().registration_handler
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
|
self.store = self.hs.get_datastore()
|
||||||
|
self.hs.config.max_mau_value = 50
|
||||||
|
self.lots_of_users = 100
|
||||||
|
self.small_number_of_users = 1
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
|
@ -80,51 +84,44 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_cannot_register_when_mau_limits_exceeded(self):
|
def test_mau_limits_when_disabled(self):
|
||||||
local_part = "someone"
|
|
||||||
display_name = "someone"
|
|
||||||
requester = create_requester("@as:test")
|
|
||||||
store = self.hs.get_datastore()
|
|
||||||
self.hs.config.limit_usage_by_mau = False
|
self.hs.config.limit_usage_by_mau = False
|
||||||
self.hs.config.max_mau_value = 50
|
|
||||||
lots_of_users = 100
|
|
||||||
small_number_users = 1
|
|
||||||
|
|
||||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
|
||||||
|
|
||||||
# Ensure does not throw exception
|
# Ensure does not throw exception
|
||||||
yield self.handler.get_or_create_user(requester, 'a', display_name)
|
yield self.handler.get_or_create_user("requester", 'a', "display_name")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_or_create_user_mau_not_blocked(self):
|
||||||
self.hs.config.limit_usage_by_mau = True
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
self.store.count_monthly_users = Mock(
|
||||||
with self.assertRaises(RegistrationError):
|
return_value=defer.succeed(self.small_number_of_users)
|
||||||
yield self.handler.get_or_create_user(requester, 'b', display_name)
|
)
|
||||||
|
|
||||||
store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
|
|
||||||
|
|
||||||
self._macaroon_mock_generator("another_secret")
|
|
||||||
|
|
||||||
# Ensure does not throw exception
|
# Ensure does not throw exception
|
||||||
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
|
yield self.handler.get_or_create_user("@user:server", 'c', "User")
|
||||||
|
|
||||||
self._macaroon_mock_generator("another another secret")
|
@defer.inlineCallbacks
|
||||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
def test_get_or_create_user_mau_blocked(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
self.store.get_monthly_active_count = Mock(
|
||||||
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
|
)
|
||||||
|
|
||||||
with self.assertRaises(RegistrationError):
|
with self.assertRaises(RegistrationError):
|
||||||
yield self.handler.register(localpart=local_part)
|
yield self.handler.get_or_create_user("requester", 'b', "display_name")
|
||||||
|
|
||||||
self._macaroon_mock_generator("another another secret")
|
|
||||||
store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_register_mau_blocked(self):
|
||||||
|
self.hs.config.limit_usage_by_mau = True
|
||||||
|
self.store.get_monthly_active_count = Mock(
|
||||||
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
|
)
|
||||||
with self.assertRaises(RegistrationError):
|
with self.assertRaises(RegistrationError):
|
||||||
yield self.handler.register_saml2(local_part)
|
yield self.handler.register(localpart="local_part")
|
||||||
|
|
||||||
def _macaroon_mock_generator(self, secret):
|
@defer.inlineCallbacks
|
||||||
"""
|
def test_register_saml2_mau_blocked(self):
|
||||||
Reset macaroon generator in the case where the test creates multiple users
|
self.hs.config.limit_usage_by_mau = True
|
||||||
"""
|
self.store.get_monthly_active_count = Mock(
|
||||||
macaroon_generator = Mock(
|
return_value=defer.succeed(self.lots_of_users)
|
||||||
generate_access_token=Mock(return_value=secret))
|
)
|
||||||
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
|
with self.assertRaises(RegistrationError):
|
||||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
yield self.handler.register_saml2(localpart="local_part")
|
||||||
self.handler = self.hs.get_handlers().registration_handler
|
|
||||||
|
|
Loading…
Reference in a new issue