mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-16 23:53:52 +01:00
Speed up sliding sync by computing extensions in parallel (#17884)
The main change here is to add a helper function `gather_optional_coroutines`, which works in a similar way as `yieldable_gather_results` but takes a set of coroutines rather than a function
This commit is contained in:
parent
58deef5eba
commit
83513b75f7
5 changed files with 285 additions and 13 deletions
1
changelog.d/17884.misc
Normal file
1
changelog.d/17884.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Minor speed-up of sliding sync by computing extensions results in parallel.
|
|
@ -49,7 +49,10 @@ from synapse.types.handlers.sliding_sync import (
|
|||
SlidingSyncConfig,
|
||||
SlidingSyncResult,
|
||||
)
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.async_helpers import (
|
||||
concurrently_execute,
|
||||
gather_optional_coroutines,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -97,26 +100,26 @@ class SlidingSyncExtensionHandler:
|
|||
if sync_config.extensions is None:
|
||||
return SlidingSyncResult.Extensions()
|
||||
|
||||
to_device_response = None
|
||||
to_device_coro = None
|
||||
if sync_config.extensions.to_device is not None:
|
||||
to_device_response = await self.get_to_device_extension_response(
|
||||
to_device_coro = self.get_to_device_extension_response(
|
||||
sync_config=sync_config,
|
||||
to_device_request=sync_config.extensions.to_device,
|
||||
to_token=to_token,
|
||||
)
|
||||
|
||||
e2ee_response = None
|
||||
e2ee_coro = None
|
||||
if sync_config.extensions.e2ee is not None:
|
||||
e2ee_response = await self.get_e2ee_extension_response(
|
||||
e2ee_coro = self.get_e2ee_extension_response(
|
||||
sync_config=sync_config,
|
||||
e2ee_request=sync_config.extensions.e2ee,
|
||||
to_token=to_token,
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
account_data_response = None
|
||||
account_data_coro = None
|
||||
if sync_config.extensions.account_data is not None:
|
||||
account_data_response = await self.get_account_data_extension_response(
|
||||
account_data_coro = self.get_account_data_extension_response(
|
||||
sync_config=sync_config,
|
||||
previous_connection_state=previous_connection_state,
|
||||
new_connection_state=new_connection_state,
|
||||
|
@ -127,9 +130,9 @@ class SlidingSyncExtensionHandler:
|
|||
from_token=from_token,
|
||||
)
|
||||
|
||||
receipts_response = None
|
||||
receipts_coro = None
|
||||
if sync_config.extensions.receipts is not None:
|
||||
receipts_response = await self.get_receipts_extension_response(
|
||||
receipts_coro = self.get_receipts_extension_response(
|
||||
sync_config=sync_config,
|
||||
previous_connection_state=previous_connection_state,
|
||||
new_connection_state=new_connection_state,
|
||||
|
@ -141,9 +144,9 @@ class SlidingSyncExtensionHandler:
|
|||
from_token=from_token,
|
||||
)
|
||||
|
||||
typing_response = None
|
||||
typing_coro = None
|
||||
if sync_config.extensions.typing is not None:
|
||||
typing_response = await self.get_typing_extension_response(
|
||||
typing_coro = self.get_typing_extension_response(
|
||||
sync_config=sync_config,
|
||||
actual_lists=actual_lists,
|
||||
actual_room_ids=actual_room_ids,
|
||||
|
@ -153,6 +156,20 @@ class SlidingSyncExtensionHandler:
|
|||
from_token=from_token,
|
||||
)
|
||||
|
||||
(
|
||||
to_device_response,
|
||||
e2ee_response,
|
||||
account_data_response,
|
||||
receipts_response,
|
||||
typing_response,
|
||||
) = await gather_optional_coroutines(
|
||||
to_device_coro,
|
||||
e2ee_coro,
|
||||
account_data_coro,
|
||||
receipts_coro,
|
||||
typing_coro,
|
||||
)
|
||||
|
||||
return SlidingSyncResult.Extensions(
|
||||
to_device=to_device_response,
|
||||
e2ee=e2ee_response,
|
||||
|
|
|
@ -37,6 +37,7 @@ import warnings
|
|||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Optional,
|
||||
|
@ -850,6 +851,45 @@ def run_in_background(
|
|||
return d
|
||||
|
||||
|
||||
def run_coroutine_in_background(
|
||||
coroutine: typing.Coroutine[Any, Any, R],
|
||||
) -> "defer.Deferred[R]":
|
||||
"""Run the coroutine, ensuring that the current context is restored after
|
||||
return from the function, and that the sentinel context is set once the
|
||||
deferred returned by the function completes.
|
||||
|
||||
Useful for wrapping coroutines that you don't yield or await on (for
|
||||
instance because you want to pass it to deferred.gatherResults()).
|
||||
|
||||
This is a special case of `run_in_background` where we can accept a
|
||||
coroutine directly rather than a function. We can do this because coroutines
|
||||
do not run until called, and so calling an async function without awaiting
|
||||
cannot change the log contexts.
|
||||
"""
|
||||
|
||||
current = current_context()
|
||||
d = defer.ensureDeferred(coroutine)
|
||||
|
||||
# The function may have reset the context before returning, so
|
||||
# we need to restore it now.
|
||||
ctx = set_current_context(current)
|
||||
|
||||
# The original context will be restored when the deferred
|
||||
# completes, but there is nothing waiting for it, so it will
|
||||
# get leaked into the reactor or some other function which
|
||||
# wasn't expecting it. We therefore need to reset the context
|
||||
# here.
|
||||
#
|
||||
# (If this feels asymmetric, consider it this way: we are
|
||||
# effectively forking a new thread of execution. We are
|
||||
# probably currently within a ``with LoggingContext()`` block,
|
||||
# which is supposed to have a single entry and exit point. But
|
||||
# by spawning off another deferred, we are effectively
|
||||
# adding a new exit point.)
|
||||
d.addBoth(_set_context_cb, ctx)
|
||||
return d
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ from typing import (
|
|||
)
|
||||
|
||||
import attr
|
||||
from typing_extensions import Concatenate, Literal, ParamSpec
|
||||
from typing_extensions import Concatenate, Literal, ParamSpec, Unpack
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError
|
||||
|
@ -61,6 +61,7 @@ from twisted.python.failure import Failure
|
|||
from synapse.logging.context import (
|
||||
PreserveLoggingContext,
|
||||
make_deferred_yieldable,
|
||||
run_coroutine_in_background,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.util import Clock
|
||||
|
@ -344,6 +345,7 @@ T1 = TypeVar("T1")
|
|||
T2 = TypeVar("T2")
|
||||
T3 = TypeVar("T3")
|
||||
T4 = TypeVar("T4")
|
||||
T5 = TypeVar("T5")
|
||||
|
||||
|
||||
@overload
|
||||
|
@ -402,6 +404,112 @@ def gather_results( # type: ignore[misc]
|
|||
return deferred.addCallback(tuple)
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]],
|
||||
) -> Tuple[Optional[T1]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[Optional[T1], Optional[T2]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
Optional[Coroutine[Any, Any, T3]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
Optional[Coroutine[Any, Any, T3]],
|
||||
Optional[Coroutine[Any, Any, T4]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
Optional[Coroutine[Any, Any, T3]],
|
||||
Optional[Coroutine[Any, Any, T4]],
|
||||
Optional[Coroutine[Any, Any, T5]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
|
||||
|
||||
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
|
||||
) -> Tuple[Optional[T1], ...]:
|
||||
"""Helper function that allows waiting on multiple coroutines at once.
|
||||
|
||||
The return value is a tuple of the return values of the coroutines in order.
|
||||
|
||||
If a `None` is passed instead of a coroutine, it will be ignored and a None
|
||||
is returned in the tuple.
|
||||
|
||||
Note: For typechecking we need to have an explicit overload for each
|
||||
distinct number of coroutines passed in. If you see type problems, it's
|
||||
likely because you're using many arguments and you need to add a new
|
||||
overload above.
|
||||
"""
|
||||
|
||||
try:
|
||||
results = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_coroutine_in_background(coroutine)
|
||||
for coroutine in coroutines
|
||||
if coroutine is not None
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
)
|
||||
|
||||
results_iter = iter(results)
|
||||
return tuple(
|
||||
next(results_iter) if coroutine is not None else None
|
||||
for coroutine in coroutines
|
||||
)
|
||||
except defer.FirstError as dfe:
|
||||
# unwrap the error from defer.gatherResults.
|
||||
|
||||
# The raised exception's traceback only includes func() etc if
|
||||
# the 'await' happens before the exception is thrown - ie if the failure
|
||||
# happens *asynchronously* - otherwise Twisted throws away the traceback as it
|
||||
# could be large.
|
||||
#
|
||||
# We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
|
||||
# we could throw Twisted into the fires of Mordor.
|
||||
|
||||
# suppress exception chaining, because the FirstError doesn't tell us anything
|
||||
# very interesting.
|
||||
assert isinstance(dfe.subFailure.value, BaseException)
|
||||
raise dfe.subFailure.value from None
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class _LinearizerEntry:
|
||||
# The number of things executing.
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#
|
||||
#
|
||||
import traceback
|
||||
from typing import Generator, List, NoReturn, Optional
|
||||
from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar
|
||||
|
||||
from parameterized import parameterized_class
|
||||
|
||||
|
@ -39,6 +39,7 @@ from synapse.util.async_helpers import (
|
|||
ObservableDeferred,
|
||||
concurrently_execute,
|
||||
delay_cancellation,
|
||||
gather_optional_coroutines,
|
||||
stop_cancellation,
|
||||
timeout_deferred,
|
||||
)
|
||||
|
@ -46,6 +47,8 @@ from synapse.util.async_helpers import (
|
|||
from tests.server import get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ObservableDeferredTest(TestCase):
|
||||
def test_succeed(self) -> None:
|
||||
|
@ -588,3 +591,106 @@ class AwakenableSleeperTests(TestCase):
|
|||
sleeper.wake("name")
|
||||
self.assertTrue(d1.called)
|
||||
self.assertTrue(d2.called)
|
||||
|
||||
|
||||
class GatherCoroutineTests(TestCase):
|
||||
"""Tests for `gather_optional_coroutines`"""
|
||||
|
||||
def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]:
|
||||
"""Returns a coroutine and a deferred that it is waiting on to resolve"""
|
||||
|
||||
d: "defer.Deferred[T]" = defer.Deferred()
|
||||
|
||||
async def inner() -> T:
|
||||
with PreserveLoggingContext():
|
||||
return await d
|
||||
|
||||
return inner(), d
|
||||
|
||||
def test_single(self) -> None:
|
||||
"Test passing in a single coroutine works"
|
||||
|
||||
with LoggingContext("test_ctx") as text_ctx:
|
||||
deferred: "defer.Deferred[None]"
|
||||
coroutine, deferred = self.make_coroutine()
|
||||
|
||||
gather_deferred = defer.ensureDeferred(
|
||||
gather_optional_coroutines(coroutine)
|
||||
)
|
||||
|
||||
# We shouldn't have a result yet, and should be in the sentinel
|
||||
# context.
|
||||
self.assertNoResult(gather_deferred)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
|
||||
# Resolving the deferred will resolve the coroutine
|
||||
deferred.callback(None)
|
||||
|
||||
# All coroutines have resolved, and so we should have the results
|
||||
result = self.successResultOf(gather_deferred)
|
||||
self.assertEqual(result, (None,))
|
||||
|
||||
# We should be back in the normal context.
|
||||
self.assertEqual(current_context(), text_ctx)
|
||||
|
||||
def test_multiple_resolve(self) -> None:
|
||||
"Test passing in multiple coroutine that all resolve works"
|
||||
|
||||
with LoggingContext("test_ctx") as test_ctx:
|
||||
deferred1: "defer.Deferred[int]"
|
||||
coroutine1, deferred1 = self.make_coroutine()
|
||||
deferred2: "defer.Deferred[str]"
|
||||
coroutine2, deferred2 = self.make_coroutine()
|
||||
|
||||
gather_deferred = defer.ensureDeferred(
|
||||
gather_optional_coroutines(coroutine1, coroutine2)
|
||||
)
|
||||
|
||||
# We shouldn't have a result yet, and should be in the sentinel
|
||||
# context.
|
||||
self.assertNoResult(gather_deferred)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
|
||||
# Even if we resolve one of the coroutines, we shouldn't have a result
|
||||
# yet
|
||||
deferred2.callback("test")
|
||||
self.assertNoResult(gather_deferred)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
|
||||
deferred1.callback(1)
|
||||
|
||||
# All coroutines have resolved, and so we should have the results
|
||||
result = self.successResultOf(gather_deferred)
|
||||
self.assertEqual(result, (1, "test"))
|
||||
|
||||
# We should be back in the normal context.
|
||||
self.assertEqual(current_context(), test_ctx)
|
||||
|
||||
def test_multiple_fail(self) -> None:
|
||||
"Test passing in multiple coroutine where one fails does the right thing"
|
||||
|
||||
with LoggingContext("test_ctx") as test_ctx:
|
||||
deferred1: "defer.Deferred[int]"
|
||||
coroutine1, deferred1 = self.make_coroutine()
|
||||
deferred2: "defer.Deferred[str]"
|
||||
coroutine2, deferred2 = self.make_coroutine()
|
||||
|
||||
gather_deferred = defer.ensureDeferred(
|
||||
gather_optional_coroutines(coroutine1, coroutine2)
|
||||
)
|
||||
|
||||
# We shouldn't have a result yet, and should be in the sentinel
|
||||
# context.
|
||||
self.assertNoResult(gather_deferred)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
|
||||
# Throw an exception in one of the coroutines
|
||||
exc = Exception("test")
|
||||
deferred2.errback(exc)
|
||||
|
||||
# Expect the gather deferred to immediately fail
|
||||
result_exc = self.failureResultOf(gather_deferred)
|
||||
self.assertEqual(result_exc.value, exc)
|
||||
|
||||
# We should be back in the normal context.
|
||||
self.assertEqual(current_context(), test_ctx)
|
||||
|
|
Loading…
Reference in a new issue