0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 05:53:51 +01:00

Initial implementation of federation server rate limiting

This commit is contained in:
Erik Johnston 2015-02-26 16:15:26 +00:00
parent a025055643
commit 93d90765c4
2 changed files with 182 additions and 5 deletions

View file

@ -21,7 +21,7 @@ support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol. communicate over a different (albeit still reliable) protocol.
""" """
from .server import TransportLayerServer from .server import TransportLayerServer, FederationRateLimiter
from .client import TransportLayerClient from .client import TransportLayerClient
@ -55,8 +55,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
send requests send requests
""" """
self.keyring = homeserver.get_keyring() self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name self.server_name = server_name
self.server = server self.server = server
self.client = client self.client = client
self.request_handler = None self.request_handler = None
self.received_handler = None self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=10000,
sleep_limit=10,
sleep_msec=500,
reject_limit=50,
concurrent_requests=3,
)

View file

@ -16,9 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError, LimitExceededError
from synapse.util.async import sleep
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import collections
import logging import logging
import simplejson as json import simplejson as json
import re import re
@ -27,6 +29,163 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.ratelimiters = {}
def ratelimit(self, host):
return self.ratelimiters.setdefault(
host,
PerHostRatelimiter(
clock=self.clock,
window_size=self.window_size,
sleep_limit=self.sleep_limit,
sleep_msec=self.sleep_msec,
reject_limit=self.reject_limit,
concurrent_requests=self.concurrent_requests,
)
).ratelimit()
class PerHostRatelimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.sleeping_requests = set()
self.ready_request_queue = collections.OrderedDict()
self.current_processing = set()
self.request_times = []
def is_empty(self):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
return not (
self.ready_request_queue
or self.sleeping_requests
or self.current_processing
or self.request_times
)
def ratelimit(self):
request_id = object()
def on_enter():
return self._on_enter(request_id)
def on_exit(exc_type, exc_val, exc_tb):
return self._on_exit(request_id)
return ContextManagerFunction(on_enter, on_exit)
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
retry_after_ms=int(
self.window_size / self.sleep_limit
),
)
self.request_times.append(time_now)
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
return queue_defer
else:
return defer.succeed(None)
logger.debug("Ratelimit [%s]: len(self.request_times)=%d", id(request_id), len(self.request_times))
logger.debug("Ratelimit [%s]: len(self.request_times)=%d", id(request_id), len(self.request_times))
if len(self.request_times) > self.sleep_limit:
logger.debug("Ratelimit [%s]: sleeping req", id(request_id))
ret_defer = sleep(self.sleep_msec/1000.0)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
ret_defer.addBoth(on_wait_finished)
else:
ret_defer = queue_request()
def on_start(r):
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
def on_err(r):
self.current_processing.discard(request_id)
return r
def on_both(r):
# Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None)
return r
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
return ret_defer
def _on_exit(self, request_id):
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
self.current_processing.add(request_id)
deferred.callback(None)
except KeyError:
pass
class ContextManagerFunction(object):
def __init__(self, on_enter, on_exit):
self.on_enter = on_enter
self.on_exit = on_exit
def __enter__(self):
if self.on_enter:
return self.on_enter()
def __exit__(self, exc_type, exc_val, exc_tb):
if self.on_exit:
return self.on_exit(exc_type, exc_val, exc_tb)
class TransportLayerServer(object): class TransportLayerServer(object):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
@ -98,6 +257,8 @@ class TransportLayerServer(object):
def new_handler(request, *args, **kwargs): def new_handler(request, *args, **kwargs):
try: try:
(origin, content) = yield self._authenticate_request(request) (origin, content) = yield self._authenticate_request(request)
with self.ratelimiter.ratelimit(origin) as d:
yield d
response = yield handler( response = yield handler(
origin, content, request.args, *args, **kwargs origin, content, request.args, *args, **kwargs
) )
@ -107,6 +268,12 @@ class TransportLayerServer(object):
defer.returnValue(response) defer.returnValue(response)
return new_handler return new_handler
def rate_limit_origin(self, handler):
def new_handler(origin, *args, **kwargs):
response = yield handler(origin, *args, **kwargs)
defer.returnValue(response)
return new_handler()
@log_function @log_function
def register_received_handler(self, handler): def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data. """ Register a handler that will be fired when we receive data.