mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 17:43:16 +01:00
parent
d9f44fd0b9
commit
78b5102ae7
3 changed files with 125 additions and 24 deletions
1
changelog.d/10078.misc
Normal file
1
changelog.d/10078.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Fix up `BatchingQueue` implementation.
|
|
@ -25,10 +25,11 @@ from typing import (
|
|||
TypeVar,
|
||||
)
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import Clock
|
||||
|
||||
|
@ -38,6 +39,24 @@ logger = logging.getLogger(__name__)
|
|||
V = TypeVar("V")
|
||||
R = TypeVar("R")
|
||||
|
||||
number_queued = Gauge(
|
||||
"synapse_util_batching_queue_number_queued",
|
||||
"The number of items waiting in the queue across all keys",
|
||||
labelnames=("name",),
|
||||
)
|
||||
|
||||
number_in_flight = Gauge(
|
||||
"synapse_util_batching_queue_number_pending",
|
||||
"The number of items across all keys either being processed or waiting in a queue",
|
||||
labelnames=("name",),
|
||||
)
|
||||
|
||||
number_of_keys = Gauge(
|
||||
"synapse_util_batching_queue_number_of_keys",
|
||||
"The number of distinct keys that have items queued",
|
||||
labelnames=("name",),
|
||||
)
|
||||
|
||||
|
||||
class BatchingQueue(Generic[V, R]):
|
||||
"""A queue that batches up work, calling the provided processing function
|
||||
|
@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]):
|
|||
called, and will keep being called until the queue has been drained (for the
|
||||
given key).
|
||||
|
||||
If the processing function raises an exception then the exception is proxied
|
||||
through to the callers waiting on that batch of work.
|
||||
|
||||
Note that the return value of `add_to_queue` will be the return value of the
|
||||
processing function that processed the given item. This means that the
|
||||
returned value will likely include data for other items that were in the
|
||||
batch.
|
||||
|
||||
Args:
|
||||
name: A name for the queue, used for logging contexts and metrics.
|
||||
This must be unique, otherwise the metrics will be wrong.
|
||||
clock: The clock to use to schedule work.
|
||||
process_batch_callback: The callback to to be run to process a batch of
|
||||
work.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -73,19 +102,15 @@ class BatchingQueue(Generic[V, R]):
|
|||
# The function to call with batches of values.
|
||||
self._process_batch_callback = process_batch_callback
|
||||
|
||||
LaterGauge(
|
||||
"synapse_util_batching_queue_number_queued",
|
||||
"The number of items waiting in the queue across all keys",
|
||||
labels=("name",),
|
||||
caller=lambda: sum(len(v) for v in self._next_values.values()),
|
||||
number_queued.labels(self._name).set_function(
|
||||
lambda: sum(len(q) for q in self._next_values.values())
|
||||
)
|
||||
|
||||
LaterGauge(
|
||||
"synapse_util_batching_queue_number_of_keys",
|
||||
"The number of distinct keys that have items queued",
|
||||
labels=("name",),
|
||||
caller=lambda: len(self._next_values),
|
||||
)
|
||||
number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
|
||||
|
||||
self._number_in_flight_metric = number_in_flight.labels(
|
||||
self._name
|
||||
) # type: Gauge
|
||||
|
||||
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
|
||||
"""Adds the value to the queue with the given key, returning the result
|
||||
|
@ -107,17 +132,18 @@ class BatchingQueue(Generic[V, R]):
|
|||
if key not in self._processing_keys:
|
||||
run_as_background_process(self._name, self._process_queue, key)
|
||||
|
||||
return await make_deferred_yieldable(d)
|
||||
with self._number_in_flight_metric.track_inprogress():
|
||||
return await make_deferred_yieldable(d)
|
||||
|
||||
async def _process_queue(self, key: Hashable) -> None:
|
||||
"""A background task to repeatedly pull things off the queue for the
|
||||
given key and call the `self._process_batch_callback` with the values.
|
||||
"""
|
||||
|
||||
try:
|
||||
if key in self._processing_keys:
|
||||
return
|
||||
if key in self._processing_keys:
|
||||
return
|
||||
|
||||
try:
|
||||
self._processing_keys.add(key)
|
||||
|
||||
while True:
|
||||
|
@ -137,16 +163,16 @@ class BatchingQueue(Generic[V, R]):
|
|||
values = [value for value, _ in next_values]
|
||||
results = await self._process_batch_callback(values)
|
||||
|
||||
for _, deferred in next_values:
|
||||
with PreserveLoggingContext():
|
||||
with PreserveLoggingContext():
|
||||
for _, deferred in next_values:
|
||||
deferred.callback(results)
|
||||
|
||||
except Exception as e:
|
||||
for _, deferred in next_values:
|
||||
if deferred.called:
|
||||
continue
|
||||
with PreserveLoggingContext():
|
||||
for _, deferred in next_values:
|
||||
if deferred.called:
|
||||
continue
|
||||
|
||||
with PreserveLoggingContext():
|
||||
deferred.errback(e)
|
||||
|
||||
finally:
|
||||
|
|
|
@ -14,7 +14,12 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util.batching_queue import BatchingQueue
|
||||
from synapse.util.batching_queue import (
|
||||
BatchingQueue,
|
||||
number_in_flight,
|
||||
number_of_keys,
|
||||
number_queued,
|
||||
)
|
||||
|
||||
from tests.server import get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase):
|
|||
def setUp(self):
|
||||
self.clock, hs_clock = get_clock()
|
||||
|
||||
# We ensure that we remove any existing metrics for "test_queue".
|
||||
try:
|
||||
number_queued.remove("test_queue")
|
||||
number_of_keys.remove("test_queue")
|
||||
number_in_flight.remove("test_queue")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
self._pending_calls = []
|
||||
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
|
||||
|
||||
|
@ -32,6 +45,41 @@ class BatchingQueueTestCase(TestCase):
|
|||
self._pending_calls.append((values, d))
|
||||
return await make_deferred_yieldable(d)
|
||||
|
||||
def _assert_metrics(self, queued, keys, in_flight):
|
||||
"""Assert that the metrics are correct"""
|
||||
|
||||
self.assertEqual(len(number_queued.collect()), 1)
|
||||
self.assertEqual(len(number_queued.collect()[0].samples), 1)
|
||||
self.assertEqual(
|
||||
number_queued.collect()[0].samples[0].labels,
|
||||
{"name": self.queue._name},
|
||||
)
|
||||
self.assertEqual(
|
||||
number_queued.collect()[0].samples[0].value,
|
||||
queued,
|
||||
"number_queued",
|
||||
)
|
||||
|
||||
self.assertEqual(len(number_of_keys.collect()), 1)
|
||||
self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
|
||||
self.assertEqual(
|
||||
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
||||
)
|
||||
self.assertEqual(
|
||||
number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
|
||||
)
|
||||
|
||||
self.assertEqual(len(number_in_flight.collect()), 1)
|
||||
self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
|
||||
self.assertEqual(
|
||||
number_queued.collect()[0].samples[0].labels, {"name": self.queue._name}
|
||||
)
|
||||
self.assertEqual(
|
||||
number_in_flight.collect()[0].samples[0].value,
|
||||
in_flight,
|
||||
"number_in_flight",
|
||||
)
|
||||
|
||||
def test_simple(self):
|
||||
"""Tests the basic case of calling `add_to_queue` once and having
|
||||
`_process_queue` return.
|
||||
|
@ -41,6 +89,8 @@ class BatchingQueueTestCase(TestCase):
|
|||
|
||||
queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
|
||||
|
||||
self._assert_metrics(queued=1, keys=1, in_flight=1)
|
||||
|
||||
# The queue should wait a reactor tick before calling the processing
|
||||
# function.
|
||||
self.assertFalse(self._pending_calls)
|
||||
|
@ -52,12 +102,15 @@ class BatchingQueueTestCase(TestCase):
|
|||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo"])
|
||||
self.assertFalse(queue_d.called)
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=1)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back.
|
||||
self._pending_calls.pop()[1].callback("bar")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d), "bar")
|
||||
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||
|
||||
def test_batching(self):
|
||||
"""Test that multiple calls at the same time get batched up into one
|
||||
call to `_process_queue`.
|
||||
|
@ -68,6 +121,8 @@ class BatchingQueueTestCase(TestCase):
|
|||
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||
|
||||
self._assert_metrics(queued=2, keys=1, in_flight=2)
|
||||
|
||||
self.clock.pump([0])
|
||||
|
||||
# We should see only *one* call to `_process_queue`
|
||||
|
@ -75,12 +130,14 @@ class BatchingQueueTestCase(TestCase):
|
|||
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
|
||||
self.assertFalse(queue_d1.called)
|
||||
self.assertFalse(queue_d2.called)
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=2)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to both.
|
||||
self._pending_calls.pop()[1].callback("bar")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d1), "bar")
|
||||
self.assertEqual(self.successResultOf(queue_d2), "bar")
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||
|
||||
def test_queuing(self):
|
||||
"""Test that we queue up requests while a `_process_queue` is being
|
||||
|
@ -92,13 +149,20 @@ class BatchingQueueTestCase(TestCase):
|
|||
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
|
||||
self.clock.pump([0])
|
||||
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
|
||||
# We queue up work after the process function has been called, testing
|
||||
# that they get correctly queued up.
|
||||
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
|
||||
queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3"))
|
||||
|
||||
# We should see only *one* call to `_process_queue`
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo1"])
|
||||
self.assertFalse(queue_d1.called)
|
||||
self.assertFalse(queue_d2.called)
|
||||
self.assertFalse(queue_d3.called)
|
||||
self._assert_metrics(queued=2, keys=1, in_flight=3)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# first.
|
||||
|
@ -106,18 +170,24 @@ class BatchingQueueTestCase(TestCase):
|
|||
|
||||
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||
self.assertFalse(queue_d2.called)
|
||||
self.assertFalse(queue_d3.called)
|
||||
self._assert_metrics(queued=2, keys=1, in_flight=2)
|
||||
|
||||
# We should now see a second call to `_process_queue`
|
||||
self.clock.pump([0])
|
||||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo2"])
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"])
|
||||
self.assertFalse(queue_d2.called)
|
||||
self.assertFalse(queue_d3.called)
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=2)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# second.
|
||||
self._pending_calls.pop()[1].callback("bar2")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d2), "bar2")
|
||||
self.assertEqual(self.successResultOf(queue_d3), "bar2")
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||
|
||||
def test_different_keys(self):
|
||||
"""Test that calls to different keys get processed in parallel."""
|
||||
|
@ -140,6 +210,7 @@ class BatchingQueueTestCase(TestCase):
|
|||
self.assertFalse(queue_d1.called)
|
||||
self.assertFalse(queue_d2.called)
|
||||
self.assertFalse(queue_d3.called)
|
||||
self._assert_metrics(queued=1, keys=1, in_flight=3)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# first.
|
||||
|
@ -148,6 +219,7 @@ class BatchingQueueTestCase(TestCase):
|
|||
self.assertEqual(self.successResultOf(queue_d1), "bar1")
|
||||
self.assertFalse(queue_d2.called)
|
||||
self.assertFalse(queue_d3.called)
|
||||
self._assert_metrics(queued=1, keys=1, in_flight=2)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# second.
|
||||
|
@ -161,9 +233,11 @@ class BatchingQueueTestCase(TestCase):
|
|||
self.assertEqual(len(self._pending_calls), 1)
|
||||
self.assertEqual(self._pending_calls[0][0], ["foo3"])
|
||||
self.assertFalse(queue_d3.called)
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=1)
|
||||
|
||||
# Return value of the `_process_queue` should be propagated back to the
|
||||
# third deferred.
|
||||
self._pending_calls.pop()[1].callback("bar4")
|
||||
|
||||
self.assertEqual(self.successResultOf(queue_d3), "bar4")
|
||||
self._assert_metrics(queued=0, keys=0, in_flight=0)
|
||||
|
|
Loading…
Reference in a new issue