0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-09-29 13:08:57 +02:00

Merge branch 'ratelimiting' into develop

This commit is contained in:
Mark Haines 2014-09-03 09:15:52 +01:00
commit 30ad0c5674
14 changed files with 244 additions and 10 deletions

View file

@ -28,6 +28,7 @@ class Codes(object):
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
class CodeMessageException(Exception): class CodeMessageException(Exception):
@ -39,10 +40,13 @@ class CodeMessageException(Exception):
self.code = code self.code = code
self.msg = msg self.msg = msg
def error_dict(self):
return cs_error(self.msg)
class SynapseError(CodeMessageException): class SynapseError(CodeMessageException):
"""A base error which can be caught for all synapse events.""" """A base error which can be caught for all synapse events."""
def __init__(self, code, msg, errcode=""): def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error. """Constructs a synapse error.
Args: Args:
@ -53,6 +57,11 @@ class SynapseError(CodeMessageException):
super(SynapseError, self).__init__(code, msg) super(SynapseError, self).__init__(code, msg)
self.errcode = errcode self.errcode = errcode
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
)
class RoomError(SynapseError): class RoomError(SynapseError):
"""An error raised when a room event fails.""" """An error raised when a room event fails."""
@ -91,13 +100,25 @@ class StoreError(SynapseError):
pass pass
def cs_exception(exception): class LimitExceededError(SynapseError):
if isinstance(exception, SynapseError): """A client has sent too many requests and is being throttled.
"""
def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
errcode=Codes.LIMIT_EXCEEDED):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
def error_dict(self):
return cs_error( return cs_error(
exception.msg, self.msg,
Codes.UNKNOWN if not exception.errcode else exception.errcode) self.errcode,
elif isinstance(exception, CodeMessageException): retry_after_ms=self.retry_after_ms,
return cs_error(exception.msg) )
def cs_exception(exception):
if isinstance(exception, CodeMessageException):
return exception.error_dict()
else: else:
logging.error("Unknown exception type: %s", type(exception)) logging.error("Unknown exception type: %s", type(exception))

View file

@ -0,0 +1,65 @@
import collections
class Ratelimiter(object):
"""
Ratelimit message sending by user.
"""
def __init__(self):
self.message_counts = collections.OrderedDict()
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
"""Can the user send a message?
Args:
user_id: The user sending a message.
time_now_s: The time now.
msg_rate_hz: The long term number of messages a user can send in a
second.
burst_count: How many messages the user can send before being
limited.
Returns:
A pair of a bool indicating if they can send a message now and a
time in seconds of when they can next send a message.
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.pop(
user_id, (0., time_now_s, None),
)
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * msg_rate_hz
if sent_count < 0:
allowed = True
time_start = time_now_s
messagecount = 1.
elif sent_count > burst_count - 1.:
allowed = False
else:
allowed = True
message_count += 1
self.message_counts[user_id] = (
message_count, time_start, msg_rate_hz
)
if msg_rate_hz > 0:
time_allowed = (
time_start + (message_count - burst_count + 1) / msg_rate_hz
)
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
time_allowed = -1
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
for user_id in self.message_counts.keys():
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)
time_delta = time_now_s - time_start
if message_count - time_delta * msg_rate_hz > 0:
break
else:
del self.message_counts[user_id]

View file

@ -247,6 +247,7 @@ def setup():
upload_dir=os.path.abspath("uploads"), upload_dir=os.path.abspath("uploads"),
db_name=config.database_path, db_name=config.database_path,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config,
) )
hs.register_servlets() hs.register_servlets()

View file

@ -17,8 +17,10 @@ from .tls import TlsConfig
from .server import ServerConfig from .server import ServerConfig
from .logger import LoggingConfig from .logger import LoggingConfig
from .database import DatabaseConfig from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig): class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig):
pass pass
if __name__=='__main__': if __name__=='__main__':

View file

@ -0,0 +1,21 @@
from ._base import Config
class RatelimitConfig(Config):
def __init__(self, args):
super(RatelimitConfig, self).__init__(args)
self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count
@classmethod
def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser)
rc_group = parser.add_argument_group("ratelimiting")
rc_group.add_argument(
"--rc-messages-per-second", type=float, default=0.2,
help="number of messages a client can send per second"
)
rc_group.add_argument(
"--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled"
)

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError
class BaseHandler(object): class BaseHandler(object):
@ -25,8 +26,22 @@ class BaseHandler(object):
self.room_lock = hs.get_room_lock_manager() self.room_lock = hs.get_room_lock_manager()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs self.hs = hs
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=1000*(time_allowed - time_now),
)
class BaseRoomHandler(BaseHandler): class BaseRoomHandler(BaseHandler):

View file

@ -76,6 +76,8 @@ class MessageHandler(BaseRoomHandler):
Raises: Raises:
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
self.ratelimit(event.user_id)
# TODO(paul): Why does 'event' not have a 'user' object? # TODO(paul): Why does 'event' not have a 'user' object?
user = self.hs.parse_userid(event.user_id) user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,) assert user.is_mine, "User must be our own: %s" % (user,)

View file

@ -49,6 +49,7 @@ class RoomCreationHandler(BaseRoomHandler):
SynapseError if the room ID was taken, couldn't be stored, or SynapseError if the room ID was taken, couldn't be stored, or
something went horribly wrong. something went horribly wrong.
""" """
self.ratelimit(user_id)
if "room_alias_name" in config: if "room_alias_name" in config:
room_alias = RoomAlias.create_local( room_alias = RoomAlias.create_local(

View file

@ -32,6 +32,7 @@ from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
class BaseHomeServer(object): class BaseHomeServer(object):
@ -73,6 +74,7 @@ class BaseHomeServer(object):
'resource_for_web_client', 'resource_for_web_client',
'resource_for_content_repo', 'resource_for_content_repo',
'event_sources', 'event_sources',
'ratelimiter',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -190,6 +192,9 @@ class HomeServer(BaseHomeServer):
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)
def build_ratelimiter(self):
return Ratelimiter()
def register_servlets(self): def register_servlets(self):
""" Register all servlets associated with this HomeServer. """ Register all servlets associated with this HomeServer.
""" """

0
tests/api/__init__.py Normal file
View file

View file

@ -0,0 +1,39 @@
from synapse.api.ratelimiting import Ratelimiter
import unittest
class TestRatelimiter(unittest.TestCase):
def test_allowed(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
)
self.assertTrue(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
)
self.assertFalse(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed)
self.assertEquals(20., time_allowed)
def test_pruning(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
)
self.assertIn("test_id_1", limiter.message_counts)
allowed, time_allowed = limiter.send_message(
user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
)
self.assertNotIn("test_id_1", limiter.message_counts)

View file

@ -39,6 +39,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
hs = HomeServer( hs = HomeServer(
self.hostname, self.hostname,
db_pool=None, db_pool=None,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
datastore=NonCallableMock(spec_set=[ datastore=NonCallableMock(spec_set=[
"persist_event", "persist_event",
"get_joined_hosts_for_room", "get_joined_hosts_for_room",
@ -82,6 +86,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.snapshot = Mock() self.snapshot = Mock()
self.datastore.snapshot_room.return_value = self.snapshot self.datastore.snapshot_room.return_value = self.snapshot
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite(self): def test_invite(self):
@ -342,6 +348,10 @@ class RoomCreationTest(unittest.TestCase):
]), ]),
auth=NonCallableMock(spec_set=["check"]), auth=NonCallableMock(spec_set=["check"]),
state_handler=NonCallableMock(spec_set=["handle_new_event"]), state_handler=NonCallableMock(spec_set=["handle_new_event"]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.federation = NonCallableMock(spec_set=[ self.federation = NonCallableMock(spec_set=[
@ -368,6 +378,9 @@ class RoomCreationTest(unittest.TestCase):
return defer.succeed([]) return defer.succeed([])
self.datastore.get_joined_hosts_for_room.side_effect = hosts self.datastore.get_joined_hosts_for_room.side_effect = hosts
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_room_creation(self): def test_room_creation(self):
user_id = "@foo:red" user_id = "@foo:red"

View file

@ -32,7 +32,7 @@ import logging
from ..utils import MockHttpResource, MemoryDataStore from ..utils import MockHttpResource, MemoryDataStore
from .utils import RestTestCase from .utils import RestTestCase
from mock import Mock from mock import Mock, NonCallableMock
logging.getLogger().addHandler(logging.NullHandler()) logging.getLogger().addHandler(logging.NullHandler())
@ -136,8 +136,15 @@ class EventStreamPermissionsTestCase(RestTestCase):
"call_later", "call_later",
"cancel_call_later", "cancel_call_later",
"time_msec", "time_msec",
"time"
]), ]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()

View file

@ -30,7 +30,7 @@ import urllib
from ..utils import MockHttpResource, MemoryDataStore from ..utils import MockHttpResource, MemoryDataStore
from .utils import RestTestCase from .utils import RestTestCase
from mock import Mock from mock import Mock, NonCallableMock
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/api/v1"
@ -58,7 +58,14 @@ class RoomPermissionsTestCase(RestTestCase):
replication_layer=Mock(), replication_layer=Mock(),
state_handler=state_handler, state_handler=state_handler,
persistence_service=persistence_service, persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_token(token=None):
@ -405,7 +412,14 @@ class RoomsMemberListTestCase(RestTestCase):
replication_layer=Mock(), replication_layer=Mock(),
state_handler=state_handler, state_handler=state_handler,
persistence_service=persistence_service, persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
self.auth_user_id = self.user_id self.auth_user_id = self.user_id
@ -483,7 +497,14 @@ class RoomsCreateTestCase(RestTestCase):
replication_layer=Mock(), replication_layer=Mock(),
state_handler=state_handler, state_handler=state_handler,
persistence_service=persistence_service, persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_token(token=None):
@ -573,7 +594,14 @@ class RoomTopicTestCase(RestTestCase):
replication_layer=Mock(), replication_layer=Mock(),
state_handler=state_handler, state_handler=state_handler,
persistence_service=persistence_service, persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_token(token=None):
@ -676,7 +704,14 @@ class RoomMemberStateTestCase(RestTestCase):
replication_layer=Mock(), replication_layer=Mock(),
state_handler=state_handler, state_handler=state_handler,
persistence_service=persistence_service, persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_token(token=None):
@ -801,7 +836,14 @@ class RoomMessagesTestCase(RestTestCase):
replication_layer=Mock(), replication_layer=Mock(),
state_handler=state_handler, state_handler=state_handler,
persistence_service=persistence_service, persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
) )
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_token(token=None):