make count_monthly_users async synapse/handlers/auth.py

This commit is contained in:
Neil Johnson 2018-08-01 10:21:56 +01:00
parent c507fa15ce
commit 7931393495
4 changed files with 46 additions and 38 deletions

View file

@ -144,7 +144,7 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
self._check_mau_limits() yield self._check_mau_limits()
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
@ -289,7 +289,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
) )
self._check_mau_limits() yield self._check_mau_limits()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -439,7 +439,7 @@ class RegistrationHandler(BaseHandler):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
self._check_mau_limits() yield self._check_mau_limits()
need_register = True need_register = True
try: try:
@ -534,13 +534,14 @@ class RegistrationHandler(BaseHandler):
action="join", action="join",
) )
@defer.inlineCallbacks
def _check_mau_limits(self): def _check_mau_limits(self):
""" """
Do not accept registrations if monthly active user limits exceeded Do not accept registrations if monthly active user limits exceeded
and limiting is enabled and limiting is enabled
""" """
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
current_mau = self.store.count_monthly_users() current_mau = yield self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise RegistrationError( raise RegistrationError(
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED

View file

@ -273,24 +273,24 @@ class DataStore(RoomMemberStore, RoomStore,
This method should be refactored with count_daily_users - the only This method should be refactored with count_daily_users - the only
reason not to is waiting on definition of mau reason not to is waiting on definition of mau
returns: returns:
int: count of current monthly active users defered: resolves to int
""" """
def _count_monthly_users(txn):
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT user_id FROM user_ips
WHERE last_seen > ?
GROUP BY user_id
) u
"""
try:
txn = self.db_conn.cursor()
txn.execute(sql, (thirty_days_ago,)) txn.execute(sql, (thirty_days_ago,))
count, = txn.fetchone() count, = txn.fetchone()
print "Count is %d" % (count,)
return count return count
finally:
txn.close() return self.runInteraction("count_monthly_users", _count_monthly_users)
def count_r30_users(self): def count_r30_users(self):
""" """

View file

@ -77,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_nonce) v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.assertEqual( token
"a_user",
self.auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
) )
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
self.hs.clock.now = 6000 self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
@defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
self.assertEqual( self.assertEqual(
"a_user", "a_user", user_id
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
) )
# add another "user_id" caveat, which might allow us to override the # add another "user_id" caveat, which might allow us to override the
@ -116,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user") macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError): with self.assertRaises(synapse.api.errors.AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
@ -126,7 +125,7 @@ class AuthTestCase(unittest.TestCase):
# Ensure does not throw exception # Ensure does not throw exception
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
@ -134,24 +133,30 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_exceeded(self): def test_mau_limits_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().count_monthly_users = Mock(
return_value=self.large_number_of_users return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().count_monthly_users = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock( self.hs.get_datastore().count_monthly_users = Mock(
return_value=self.small_number_of_users return_value=defer.succeed(self.small_number_of_users)
) )
# Ensure does not raise exception # Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )

View file

@ -90,7 +90,7 @@ class RegistrationTestCase(unittest.TestCase):
lots_of_users = 100 lots_of_users = 100
small_number_users = 1 small_number_users = 1
store.count_monthly_users = Mock(return_value=lots_of_users) store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user(requester, 'a', display_name) yield self.handler.get_or_create_user(requester, 'a', display_name)
@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.TestCase):
with self.assertRaises(RegistrationError): with self.assertRaises(RegistrationError):
yield self.handler.get_or_create_user(requester, 'b', display_name) yield self.handler.get_or_create_user(requester, 'b', display_name)
store.count_monthly_users = Mock(return_value=small_number_users) store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
self._macaroon_mock_generator("another_secret") self._macaroon_mock_generator("another_secret")
@ -108,12 +108,14 @@ class RegistrationTestCase(unittest.TestCase):
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
self._macaroon_mock_generator("another another secret") self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=lots_of_users) store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
with self.assertRaises(RegistrationError): with self.assertRaises(RegistrationError):
yield self.handler.register(localpart=local_part) yield self.handler.register(localpart=local_part)
self._macaroon_mock_generator("another another secret") self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=lots_of_users) store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
with self.assertRaises(RegistrationError): with self.assertRaises(RegistrationError):
yield self.handler.register_saml2(local_part) yield self.handler.register_saml2(local_part)