0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-27 15:08:18 +02:00

add new error type ResourceLimit

This commit is contained in:
Neil Johnson 2018-08-16 18:02:02 +01:00
parent c151b32b1d
commit 13ad9930c8
8 changed files with 47 additions and 24 deletions

View file

@ -25,7 +25,7 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes, ResourceLimitError
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -784,10 +784,11 @@ class Auth(object):
MAU cohort MAU cohort
""" """
if self.hs.config.hs_disabled: if self.hs.config.hs_disabled:
raise AuthError( raise ResourceLimitError(
403, self.hs.config.hs_disabled_message, 403, self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEED, errcode=Codes.RESOURCE_LIMIT_EXCEED,
admin_uri=self.hs.config.admin_uri, admin_uri=self.hs.config.admin_uri,
limit_type=self.hs.config.hs_disabled_limit_type
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
# If the user is already part of the MAU cohort # If the user is already part of the MAU cohort
@ -798,8 +799,9 @@ class Auth(object):
# Else if there is no room in the MAU bucket, bail # Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise AuthError( raise ResourceLimitError(
403, "Monthly Active User Limits AU Limit Exceeded", 403, "Monthly Active User Limits AU Limit Exceeded",
admin_uri=self.hs.config.admin_uri, admin_uri=self.hs.config.admin_uri,
errcode=Codes.RESOURCE_LIMIT_EXCEED errcode=Codes.RESOURCE_LIMIT_EXCEED,
limit_type="monthly_active_user"
) )

View file

@ -224,15 +224,34 @@ class NotFoundError(SynapseError):
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""
def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None):
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
super(AuthError, self).__init__(*args, **kwargs)
class ResourceLimitError(SynapseError):
"""
Any error raised when there is a problem with resource usage.
For instance, the monthly active user limit for the server has been exceeded
"""
def __init__(
self, code, msg,
errcode=Codes.RESOURCE_LIMIT_EXCEED,
admin_uri=None,
limit_type=None,
):
self.admin_uri = admin_uri self.admin_uri = admin_uri
super(AuthError, self).__init__(code, msg, errcode=errcode) self.limit_type = limit_type
super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
def error_dict(self): def error_dict(self):
return cs_error( return cs_error(
self.msg, self.msg,
self.errcode, self.errcode,
admin_uri=self.admin_uri, admin_uri=self.admin_uri,
limit_type=self.limit_type
) )

View file

@ -81,6 +81,7 @@ class ServerConfig(Config):
# Options to disable HS # Options to disable HS
self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "") self.hs_disabled_message = config.get("hs_disabled_message", "")
self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "")
# Admin uri to direct users at should their instance become blocked # Admin uri to direct users at should their instance become blocked
# due to resource constraints # due to resource constraints

View file

@ -21,7 +21,7 @@ from twisted.internet import defer
import synapse.handlers.auth import synapse.handlers.auth
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes, ResourceLimitError
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
@ -455,7 +455,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
) )
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
@ -471,7 +471,7 @@ class AuthTestCase(unittest.TestCase):
def test_hs_disabled(self): def test_hs_disabled(self):
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
import synapse import synapse
import synapse.api.errors import synapse.api.errors
from synapse.api.errors import AuthError from synapse.api.errors import ResourceLimitError
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
@ -130,13 +130,13 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
@ -149,13 +149,13 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id('user_a') yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )

View file

@ -17,7 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import ResourceLimitError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -109,13 +109,13 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -124,13 +124,13 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -139,11 +139,11 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
with self.assertRaises(AuthError): with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.types import UserID from synapse.types import UserID
@ -49,7 +49,7 @@ class SyncTestCase(tests.unittest.TestCase):
# Test that global lock works # Test that global lock works
self.hs.config.hs_disabled = True self.hs.config.hs_disabled = True
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config) yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
@ -57,7 +57,7 @@ class SyncTestCase(tests.unittest.TestCase):
sync_config = self._generate_sync_config(user_id2) sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(AuthError) as e: with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config) yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)

View file

@ -137,6 +137,7 @@ def setup_test_homeserver(
config.limit_usage_by_mau = False config.limit_usage_by_mau = False
config.hs_disabled = False config.hs_disabled = False
config.hs_disabled_message = "" config.hs_disabled_message = ""
config.hs_disabled_limit_type = ""
config.max_mau_value = 50 config.max_mau_value = 50
config.mau_limits_reserved_threepids = [] config.mau_limits_reserved_threepids = []
config.admin_uri = None config.admin_uri = None