0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-05 23:54:05 +01:00

Merge pull request #1176 from matrix-org/erikj/eager_ratelimit_check

Check whether to ratelimit sooner to avoid work
This commit is contained in:
Erik Johnston 2016-10-19 14:25:52 +01:00 committed by GitHub
commit 3aa8925091
2 changed files with 25 additions and 6 deletions

View file

@ -23,7 +23,7 @@ class Ratelimiter(object):
def __init__(self): def __init__(self):
self.message_counts = collections.OrderedDict() self.message_counts = collections.OrderedDict()
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count): def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
"""Can the user send a message? """Can the user send a message?
Args: Args:
user_id: The user sending a message. user_id: The user sending a message.
@ -32,12 +32,15 @@ class Ratelimiter(object):
second. second.
burst_count: How many messages the user can send before being burst_count: How many messages the user can send before being
limited. limited.
update (bool): Whether to update the message rates or not. This is
useful to check if a message would be allowed to be sent before
its ready to be actually sent.
Returns: Returns:
A pair of a bool indicating if they can send a message now and a A pair of a bool indicating if they can send a message now and a
time in seconds of when they can next send a message. time in seconds of when they can next send a message.
""" """
self.prune_message_counts(time_now_s) self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.pop( message_count, time_start, _ignored = self.message_counts.get(
user_id, (0., time_now_s, None), user_id, (0., time_now_s, None),
) )
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
@ -52,9 +55,10 @@ class Ratelimiter(object):
allowed = True allowed = True
message_count += 1 message_count += 1
self.message_counts[user_id] = ( if update:
message_count, time_start, msg_rate_hz self.message_counts[user_id] = (
) message_count, time_start, msg_rate_hz
)
if msg_rate_hz > 0: if msg_rate_hz > 0:
time_allowed = ( time_allowed = (

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -239,6 +239,21 @@ class MessageHandler(BaseHandler):
"Tried to send member event through non-member codepath" "Tried to send member event through non-member codepath"
) )
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
event.sender, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)