0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 10:33:49 +01:00

support admin_email config and pass through into blocking errors, return AuthError in all cases

This commit is contained in:
Neil Johnson 2018-08-13 18:00:23 +01:00
parent 99dd975dae
commit 0d43f991a1
7 changed files with 45 additions and 22 deletions

View file

@ -781,11 +781,15 @@ class Auth(object):
""" """
if self.hs.config.hs_disabled: if self.hs.config.hs_disabled:
raise AuthError( raise AuthError(
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED 403, self.hs.config.hs_disabled_message,
errcode=Codes.HS_DISABLED,
admin_email=self.hs.config.admin_email,
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise AuthError( raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED 403, "MAU Limit Exceeded",
admin_email=self.hs.config.admin_email,
errcode=Codes.MAU_LIMIT_EXCEEDED
) )

View file

@ -225,11 +225,20 @@ class NotFoundError(SynapseError):
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if "errcode" not in kwargs: if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN kwargs["errcode"] = Codes.FORBIDDEN
super(AuthError, self).__init__(*args, **kwargs) self.admin_email = kwargs.get('admin_email')
self.msg = kwargs.get('msg')
self.errcode = kwargs.get('errcode')
super(AuthError, self).__init__(*args, errcode=kwargs["errcode"])
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
admin_email=self.admin_email,
)
class EventSizeError(SynapseError): class EventSizeError(SynapseError):

View file

@ -82,6 +82,10 @@ class ServerConfig(Config):
self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "") self.hs_disabled_message = config.get("hs_disabled_message", "")
# Admin email to direct users at should their instance become blocked
# due to resource constraints
self.admin_email = config.get("admin_email", None)
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(

View file

@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
yield self._check_mau_limits()
yield self.auth.check_auth_blocking()
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
) )
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
need_register = True need_register = True
try: try:
@ -534,13 +535,13 @@ class RegistrationHandler(BaseHandler):
action="join", action="join",
) )
@defer.inlineCallbacks # @defer.inlineCallbacks
def _check_mau_limits(self): # def _s(self):
""" # """
Do not accept registrations if monthly active user limits exceeded # Do not accept registrations if monthly active user limits exceeded
and limiting is enabled # and limiting is enabled
""" # """
try: # try:
yield self.auth.check_auth_blocking() # yield self.auth.check_auth_blocking()
except AuthError as e: # except AuthError as e:
raise RegistrationError(e.code, str(e), e.errcode) # raise RegistrationError(e.code, str(e), e.errcode)

View file

@ -455,8 +455,11 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_email, self.hs.config.admin_email)
self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
# Ensure does not throw an error # Ensure does not throw an error
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
@ -470,5 +473,6 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(AuthError) as e: with self.assertRaises(AuthError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_email, self.hs.config.admin_email)
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
self.assertEquals(e.exception.code, 403) self.assertEquals(e.exception.code, 403)

View file

@ -17,7 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import RegistrationError from synapse.api.errors import AuthError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")

View file

@ -139,6 +139,7 @@ def setup_test_homeserver(
config.hs_disabled_message = "" config.hs_disabled_message = ""
config.max_mau_value = 50 config.max_mau_value = 50
config.mau_limits_reserved_threepids = [] config.mau_limits_reserved_threepids = []
config.admin_email = None
# we need a sane default_room_version, otherwise attempts to create rooms will # we need a sane default_room_version, otherwise attempts to create rooms will
# fail. # fail.