0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 23:33:53 +01:00

Check whether to ratelimit sooner to avoid work

This commit is contained in:
Erik Johnston 2016-10-19 10:45:24 +01:00
parent 50ac1d843d
commit 550308c7a1
2 changed files with 22 additions and 6 deletions

View file

@ -23,7 +23,7 @@ class Ratelimiter(object):
def __init__(self): def __init__(self):
self.message_counts = collections.OrderedDict() self.message_counts = collections.OrderedDict()
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count): def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
"""Can the user send a message? """Can the user send a message?
Args: Args:
user_id: The user sending a message. user_id: The user sending a message.
@ -32,12 +32,15 @@ class Ratelimiter(object):
second. second.
burst_count: How many messages the user can send before being burst_count: How many messages the user can send before being
limited. limited.
update (bool): Whether to update the message rates or not. This is
useful to check if a message would be allowed to be sent before
its ready to be actually sent.
Returns: Returns:
A pair of a bool indicating if they can send a message now and a A pair of a bool indicating if they can send a message now and a
time in seconds of when they can next send a message. time in seconds of when they can next send a message.
""" """
self.prune_message_counts(time_now_s) self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.pop( message_count, time_start, _ignored = self.message_counts.get(
user_id, (0., time_now_s, None), user_id, (0., time_now_s, None),
) )
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
@ -52,9 +55,10 @@ class Ratelimiter(object):
allowed = True allowed = True
message_count += 1 message_count += 1
self.message_counts[user_id] = ( if update:
message_count, time_start, msg_rate_hz self.message_counts[user_id] = (
) message_count, time_start, msg_rate_hz
)
if msg_rate_hz > 0: if msg_rate_hz > 0:
time_allowed = ( time_allowed = (

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -239,6 +239,18 @@ class MessageHandler(BaseHandler):
"Tried to send member event through non-member codepath" "Tried to send member event through non-member codepath"
) )
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
event.sender, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)