# Copyright 2015, 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import collections import contextlib import logging import threading import typing from typing import ( Any, Callable, ContextManager, DefaultDict, Dict, Iterator, List, Mapping, MutableSet, Optional, Set, Tuple, ) from weakref import WeakSet from prometheus_client.core import Counter from twisted.internet import defer from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import FederationRatelimitSettings from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) from synapse.logging.opentracing import start_active_span from synapse.metrics import Histogram, LaterGauge from synapse.util import Clock if typing.TYPE_CHECKING: from contextlib import _GeneratorContextManager logger = logging.getLogger(__name__) # Track how much the ratelimiter is affecting requests rate_limit_sleep_counter = Counter( "synapse_rate_limit_sleep", "Number of requests slept by the rate limiter", ["rate_limiter_name"], ) rate_limit_reject_counter = Counter( "synapse_rate_limit_reject", "Number of requests rejected by the rate limiter", ["rate_limiter_name"], ) queue_wait_timer = Histogram( "synapse_rate_limit_queue_wait_time_seconds", "Amount of time spent waiting for the rate limiter to let our request through.", ["rate_limiter_name"], buckets=( 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 10.0, 20.0, "+Inf", ), ) # This must 
# be a `WeakSet`, otherwise we indirectly hold on to entire `HomeServer`s
# during trial test runs and leak a lot of memory.
_rate_limiter_instances: MutableSet["FederationRateLimiter"] = WeakSet()
# Guards `_rate_limiter_instances` against concurrent access.
_rate_limiter_instances_lock = threading.Lock()


def _get_counts_from_rate_limiter_instance(
    count_func: Callable[["FederationRateLimiter"], int]
) -> Mapping[Tuple[str, ...], int]:
    """Compute a per-rate-limiter count, keyed by `(metrics_name,)`.

    Args:
        count_func: Called once for each registered rate limiter that has a
            `metrics_name`; returns the number to report for it (e.g. how
            many hosts are currently being slept or rejected).

    Returns:
        A mapping from `(metrics_name,)` to the value `count_func` returned.
        Rate limiters without a `metrics_name` are left out.
    """
    # Take a snapshot under the lock so the set cannot change underneath us
    # while the Prometheus thread is collecting metrics.
    with _rate_limiter_instances_lock:
        snapshot = list(_rate_limiter_instances)

    # The key is really a `Tuple[str]`, but the length is left unspecified for
    # compatibility with LaterGauge's annotations.
    counts: Dict[Tuple[str, ...], int] = {
        (limiter.metrics_name,): count_func(limiter)
        for limiter in snapshot
        # Only instances that were given a `metrics_name` are tracked.
        if limiter.metrics_name
    }
    return counts


# We track the number of affected hosts per time-period so we can
# differentiate one really noisy homeserver from a general
# ratelimit tuning problem across the federation.
LaterGauge(
    "synapse_rate_limit_sleep_affected_hosts",
    "Number of hosts that had requests put to sleep",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        lambda rate_limiter_instance: sum(
            ratelimiter.should_sleep()
            # Iterate over a copy: `ratelimiters` is a defaultdict that gains an
            # entry whenever a new host is seen, and resizing it while the
            # metrics collector is iterating would raise a RuntimeError.
            for ratelimiter in list(rate_limiter_instance.ratelimiters.values())
        )
    ),
)
LaterGauge(
    "synapse_rate_limit_reject_affected_hosts",
    "Number of hosts that had requests rejected",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        lambda rate_limiter_instance: sum(
            ratelimiter.should_reject()
            # Iterate over a copy for the same reason as the gauge above.
            for ratelimiter in list(rate_limiter_instance.ratelimiters.values())
        )
    ),
)


class FederationRateLimiter:
    """Used to rate limit request per-host.

    Keeps one `_PerHostRatelimiter` per host, created lazily the first time a
    host is seen, and registers itself in `_rate_limiter_instances` so the
    module-level gauges can report on it.
    """

    def __init__(
        self,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            clock
            config
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.metrics_name = metrics_name

        def new_limiter() -> "_PerHostRatelimiter":
            return _PerHostRatelimiter(
                clock=clock, config=config, metrics_name=metrics_name
            )

        # Map from host to its rate limiter; missing hosts are created on
        # first access by the defaultdict factory above.
        self.ratelimiters: DefaultDict[
            str, "_PerHostRatelimiter"
        ] = collections.defaultdict(new_limiter)

        with _rate_limiter_instances_lock:
            _rate_limiter_instances.add(self)

    def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
        """Used to ratelimit an incoming request from a given host

        Example usage:

            with rate_limiter.ratelimit(origin) as wait_deferred:
                yield wait_deferred
                # Handle request ...

        Args:
            host: Origin of incoming request.

        Returns:
            context manager which returns a deferred.
        """
        return self.ratelimiters[host].ratelimit(host)


class _PerHostRatelimiter:
    """Rate limiter state for a single host.

    A request is rejected when too many requests are already sleeping or
    queued, slept when too many requests arrived within the window, and
    queued when too many requests are being processed concurrently.
    """

    def __init__(
        self,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            clock
            config
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.clock = clock
        self.metrics_name = metrics_name

        self.window_size = config.window_size
        self.sleep_limit = config.sleep_limit
        # `sleep_delay` is divided by 1000 to get the seconds passed to
        # `clock.sleep`.
        self.sleep_sec = config.sleep_delay / 1000.0
        self.reject_limit = config.reject_limit
        self.concurrent_requests = config.concurrent

        # request_id objects for requests which have been slept
        self.sleeping_requests: Set[object] = set()

        # map from request_id object to Deferred for requests which are ready
        # for processing but have been queued
        self.ready_request_queue: collections.OrderedDict[
            object, defer.Deferred[None]
        ] = collections.OrderedDict()

        # request id objects for requests which are in progress
        self.current_processing: Set[object] = set()

        # times at which we have recently (within the last window_size ms)
        # received requests.
        self.request_times: List[int] = []

    @contextlib.contextmanager
    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
        """Context manager wrapping the handling of one request from `host`.

        Yields a deferred which the caller should wait on before handling the
        request. If the request is rejected, that deferred fails with
        `LimitExceededError`. Leaving the context marks the request as done.

        Args:
            host: Origin of the request; only used in log messages.
        """
        # `contextlib.contextmanager` takes a generator and turns it into a
        # context manager. The generator should only yield once with a value
        # to be returned by manager.
        # Exceptions will be reraised at the yield.

        self.host = host

        # A fresh object serves as a unique identity for this request.
        request_id = object()
        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
        # type-checking, but we'd need Twisted >= 21.2.
        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
        try:
            yield ret
        finally:
            self._on_exit(request_id)

    def should_reject(self) -> bool:
        """
        Whether to reject the request if we already have too many queued up
        (either sleeping or in the ready queue).
        """
        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
        return queue_size > self.reject_limit

    def should_sleep(self) -> bool:
        """
        Whether to sleep the request if we already have too many requests coming
        through within the window.
        """
        return len(self.request_times) > self.sleep_limit

    async def _on_enter_with_tracing(self, request_id: object) -> None:
        """Run `_on_enter` inside a tracing span and, if metrics are enabled,
        the queue wait timer."""
        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
        if self.metrics_name:
            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
        with start_active_span("ratelimit wait"), maybe_metrics_cm:
            await self._on_enter(request_id)

    def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
        """Decide whether the request is rejected, slept and/or queued.

        Args:
            request_id: Opaque object identifying the request.

        Returns:
            A deferred which fires once the request may be processed.

        Raises:
            LimitExceededError: if too many requests are already sleeping or
                queued for this host.
        """
        time_now = self.clock.time_msec()

        # remove any entries from request_times which aren't within the window
        self.request_times[:] = [
            r for r in self.request_times if time_now - r < self.window_size
        ]

        # reject the request if we already have too many queued up (either
        # sleeping or in the ready queue).
        if self.should_reject():
            logger.debug("Ratelimiter(%s): rejecting request", self.host)
            if self.metrics_name:
                rate_limit_reject_counter.labels(self.metrics_name).inc()
            # Guard against a `sleep_limit` of 0, which would otherwise turn
            # the rejection into a ZeroDivisionError; fall back to asking the
            # client to wait for a whole window.
            if self.sleep_limit > 0:
                retry_after_ms = int(self.window_size / self.sleep_limit)
            else:
                retry_after_ms = int(self.window_size)
            raise LimitExceededError(retry_after_ms=retry_after_ms)

        self.request_times.append(time_now)

        def queue_request() -> "defer.Deferred[None]":
            # Queue the request if we are already processing as many requests
            # as we are allowed to; otherwise let it through immediately.
            if len(self.current_processing) >= self.concurrent_requests:
                queue_defer: defer.Deferred[None] = defer.Deferred()
                self.ready_request_queue[request_id] = queue_defer
                logger.info(
                    "Ratelimiter(%s): queueing request (queue now %i items)",
                    self.host,
                    len(self.ready_request_queue),
                )

                return queue_defer
            else:
                return defer.succeed(None)

        logger.debug(
            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
            self.host,
            id(request_id),
            len(self.request_times),
        )

        if self.should_sleep():
            logger.debug(
                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
                self.host,
                id(request_id),
                self.sleep_sec,
            )
            if self.metrics_name:
                rate_limit_sleep_counter.labels(self.metrics_name).inc()
            ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)

            self.sleeping_requests.add(request_id)

            def on_wait_finished(_: Any) -> "defer.Deferred[None]":
                logger.debug(
                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
                )
                self.sleeping_requests.discard(request_id)
                queue_defer = queue_request()
                return queue_defer

            ret_defer.addBoth(on_wait_finished)
        else:
            ret_defer = queue_request()

        def on_start(r: object) -> object:
            logger.debug(
                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
            )
            self.current_processing.add(request_id)
            return r

        def on_err(r: object) -> object:
            # XXX: why is this necessary? this is called before we start
            # processing the request so why would the request be in
            # current_processing?
            self.current_processing.discard(request_id)
            return r

        def on_both(r: object) -> object:
            # Ensure that we've properly cleaned up.
            self.sleeping_requests.discard(request_id)
            self.ready_request_queue.pop(request_id, None)
            return r

        ret_defer.addCallbacks(on_start, on_err)
        ret_defer.addBoth(on_both)
        return make_deferred_yieldable(ret_defer)

    def _on_exit(self, request_id: object) -> None:
        """Mark a request as finished and schedule the next queued request.

        Args:
            request_id: Opaque object identifying the request that finished.
        """
        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))

        # When requests complete synchronously, we will recursively start the next
        # request in the queue. To avoid stack exhaustion, we defer starting the next
        # request until the next reactor tick.

        def start_next_request() -> None:
            # We only remove the completed request from the list when we're about to
            # start the next one, otherwise we can allow extra requests through.
            self.current_processing.discard(request_id)
            try:
                # start processing the next item on the queue.
                _, deferred = self.ready_request_queue.popitem(last=False)

                with PreserveLoggingContext():
                    deferred.callback(None)
            except KeyError:
                # Nothing was queued.
                pass

        self.clock.call_later(0.0, start_next_request)