Merge pull request #2208 from matrix-org/erikj/ratelimit_overrid

Add per user ratelimiting overrides
This commit is contained in:
Erik Johnston 2017-05-10 15:54:48 +01:00 committed by GitHub
commit a3648f84b2
6 changed files with 93 additions and 19 deletions

View file

@ -53,7 +53,20 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
def ratelimit(self, requester): @defer.inlineCallbacks
def ratelimit(self, requester, update=True):
"""Ratelimits requests.
Args:
requester (Requester)
update (bool): Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
Raises:
LimitExceededError if the request should be ratelimited
"""
time_now = self.clock.time() time_now = self.clock.time()
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -67,10 +80,25 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited(): if requester.app_service and not requester.app_service.is_rate_limited():
return return
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
if override:
# If overriden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
messages_per_second = self.hs.config.rc_messages_per_second
burst_count = self.hs.config.rc_message_burst_count
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now, user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=messages_per_second,
burst_count=self.hs.config.rc_message_burst_count, burst_count=burst_count,
update=update,
) )
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -254,17 +254,7 @@ class MessageHandler(BaseHandler):
# We check here if we are currently being rate limited, so that we # We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually # don't do unnecessary work. We check again just before we actually
# send the event. # send the event.
time_now = self.clock.time() yield self.ratelimit(requester, update=False)
allowed, time_allowed = self.ratelimiter.send_message(
event.sender, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
@ -499,7 +489,7 @@ class MessageHandler(BaseHandler):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit: if ratelimit:
self.ratelimit(requester) yield self.ratelimit(requester)
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(event, context)

View file

@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
return return
self.ratelimit(requester) yield self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user( room_ids = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),

View file

@ -75,7 +75,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
self.ratelimit(requester) yield self.ratelimit(requester)
if "room_alias_name" in config: if "room_alias_name" in config:
for wchar in string.whitespace: for wchar in string.whitespace:

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from ._base import SQLBaseStore from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine from .engines import PostgresEngine, Sqlite3Engine
@ -33,6 +33,11 @@ OpsLevel = collections.namedtuple(
("ban_level", "kick_level", "redact_level",) ("ban_level", "kick_level", "redact_level",)
) )
RatelimitOverride = collections.namedtuple(
"RatelimitOverride",
("messages_per_second", "burst_count",)
)
class RoomStore(SQLBaseStore): class RoomStore(SQLBaseStore):
@ -473,3 +478,32 @@ class RoomStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms "get_all_new_public_rooms", get_all_new_public_rooms
) )
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
Args:
user_id (str)
Returns:
RatelimitOverride if there is an override, else None. If the contents
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
row = yield self._simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
allow_none=True,
desc="get_ratelimit_for_user",
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
else:
defer.returnValue(None)

View file

@ -0,0 +1,22 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE ratelimit_override (
user_id TEXT NOT NULL,
messages_per_second BIGINT,
burst_count BIGINT
);
CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id);