mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-07 13:44:05 +01:00
Raise LimitExceedError when the ratelimiting is throttling requests
This commit is contained in:
parent
780548b577
commit
683596f91e
2 changed files with 29 additions and 11 deletions
|
@ -40,10 +40,13 @@ class CodeMessageException(Exception):
|
||||||
self.code = code
|
self.code = code
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
|
||||||
|
def error_dict(self):
|
||||||
|
return cs_error(self.msg)
|
||||||
|
|
||||||
|
|
||||||
class SynapseError(CodeMessageException):
|
class SynapseError(CodeMessageException):
|
||||||
"""A base error which can be caught for all synapse events."""
|
"""A base error which can be caught for all synapse events."""
|
||||||
def __init__(self, code, msg, errcode=""):
|
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
|
||||||
"""Constructs a synapse error.
|
"""Constructs a synapse error.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -54,6 +57,11 @@ class SynapseError(CodeMessageException):
|
||||||
super(SynapseError, self).__init__(code, msg)
|
super(SynapseError, self).__init__(code, msg)
|
||||||
self.errcode = errcode
|
self.errcode = errcode
|
||||||
|
|
||||||
|
def error_dict(self):
|
||||||
|
return cs_error(
|
||||||
|
self.msg,
|
||||||
|
self.errcode,
|
||||||
|
)
|
||||||
|
|
||||||
class RoomError(SynapseError):
|
class RoomError(SynapseError):
|
||||||
"""An error raised when a room event fails."""
|
"""An error raised when a room event fails."""
|
||||||
|
@ -92,13 +100,25 @@ class StoreError(SynapseError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def cs_exception(exception):
|
class LimitExceededError(SynapseError):
|
||||||
if isinstance(exception, SynapseError):
|
"""A client has sent too many requests and is being throttled.
|
||||||
|
"""
|
||||||
|
def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
|
||||||
|
errcode=Codes.LIMIT_EXCEEDED):
|
||||||
|
super(LimitExceededError, self).__init__(code, msg, errcode)
|
||||||
|
self.retry_after_ms = retry_after_ms
|
||||||
|
|
||||||
|
def error_dict(self):
|
||||||
return cs_error(
|
return cs_error(
|
||||||
exception.msg,
|
self.msg,
|
||||||
Codes.UNKNOWN if not exception.errcode else exception.errcode)
|
self.errcode,
|
||||||
elif isinstance(exception, CodeMessageException):
|
retry_after_ms=self.retry_after_ms,
|
||||||
return cs_error(exception.msg)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cs_exception(exception):
|
||||||
|
if isinstance(exception, CodeMessageException):
|
||||||
|
return exception.error_dict()
|
||||||
else:
|
else:
|
||||||
logging.error("Unknown exception type: %s", type(exception))
|
logging.error("Unknown exception type: %s", type(exception))
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.api.errors import cs_error, Codes
|
from synapse.api.errors import LimitExceededError
|
||||||
|
|
||||||
class BaseHandler(object):
|
class BaseHandler(object):
|
||||||
|
|
||||||
|
@ -38,9 +38,7 @@ class BaseHandler(object):
|
||||||
burst_count=self.hs.config.rc_message_burst_count,
|
burst_count=self.hs.config.rc_message_burst_count,
|
||||||
)
|
)
|
||||||
if not allowed:
|
if not allowed:
|
||||||
raise cs_error(
|
raise LimitExceededError(
|
||||||
"Limit exceeded",
|
|
||||||
Codes.LIMIT_EXCEEDED,
|
|
||||||
retry_after_ms=1000*(time_allowed - time_now),
|
retry_after_ms=1000*(time_allowed - time_now),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue