mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 11:13:49 +01:00
Handle replication commands synchronously where possible (#7876)
Most of the stuff we do for replication commands can be done synchronously. There's no point spinning up background processes if we're not going to need them.
This commit is contained in:
parent
7c2e2c2077
commit
f57b99af22
5 changed files with 113 additions and 86 deletions
1
changelog.d/7876.bugfix
Normal file
1
changelog.d/7876.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix an `AssertionError` exception introduced in v1.18.0rc1.
|
1
changelog.d/7876.misc
Normal file
1
changelog.d/7876.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Further optimise queueing of inbound replication commands.
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
|
@ -33,6 +34,7 @@ from typing_extensions import Deque
|
|||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
|
||||
from synapse.replication.tcp.commands import (
|
||||
ClearUserSyncsCommand,
|
||||
|
@ -152,7 +154,7 @@ class ReplicationCommandHandler:
|
|||
# When POSITION or RDATA commands arrive, we stick them in a queue and process
|
||||
# them in order in a separate background process.
|
||||
|
||||
# the streams which are currently being processed by _unsafe_process_stream
|
||||
# the streams which are currently being processed by _unsafe_process_queue
|
||||
self._processing_streams = set() # type: Set[str]
|
||||
|
||||
# for each stream, a queue of commands that are awaiting processing, and the
|
||||
|
@ -185,7 +187,7 @@ class ReplicationCommandHandler:
|
|||
if self._is_master:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
async def _add_command_to_stream_queue(
|
||||
def _add_command_to_stream_queue(
|
||||
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||
) -> None:
|
||||
"""Queue the given received command for processing
|
||||
|
@ -199,33 +201,34 @@ class ReplicationCommandHandler:
|
|||
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
|
||||
return
|
||||
|
||||
# if we're already processing this stream, stick the new command in the
|
||||
# queue, and we're done.
|
||||
if stream_name in self._processing_streams:
|
||||
queue.append((cmd, conn))
|
||||
|
||||
# if we're already processing this stream, there's nothing more to do:
|
||||
# the new entry on the queue will get picked up in due course
|
||||
if stream_name in self._processing_streams:
|
||||
return
|
||||
|
||||
# otherwise, process the new command.
|
||||
# fire off a background process to start processing the queue.
|
||||
run_as_background_process(
|
||||
"process-replication-data", self._unsafe_process_queue, stream_name
|
||||
)
|
||||
|
||||
# arguably we should start off a new background process here, but nothing
|
||||
# will be too upset if we don't return for ages, so let's save the overhead
|
||||
# and use the existing logcontext.
|
||||
async def _unsafe_process_queue(self, stream_name: str):
|
||||
"""Processes the command queue for the given stream, until it is empty
|
||||
|
||||
Does not check if there is already a thread processing the queue, hence "unsafe"
|
||||
"""
|
||||
assert stream_name not in self._processing_streams
|
||||
|
||||
self._processing_streams.add(stream_name)
|
||||
try:
|
||||
# might as well skip the queue for this one, since it must be empty
|
||||
assert not queue
|
||||
await self._process_command(cmd, conn, stream_name)
|
||||
|
||||
# now process any other commands that have built up while we were
|
||||
# dealing with that one.
|
||||
queue = self._command_queues_by_stream.get(stream_name)
|
||||
while queue:
|
||||
cmd, conn = queue.popleft()
|
||||
try:
|
||||
await self._process_command(cmd, conn, stream_name)
|
||||
except Exception:
|
||||
logger.exception("Failed to handle command %s", cmd)
|
||||
|
||||
finally:
|
||||
self._processing_streams.discard(stream_name)
|
||||
|
||||
|
@ -299,7 +302,7 @@ class ReplicationCommandHandler:
|
|||
"""
|
||||
return self._streams_to_replicate
|
||||
|
||||
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
self.send_positions_to_connection(conn)
|
||||
|
||||
def send_positions_to_connection(self, conn: AbstractConnection):
|
||||
|
@ -318,44 +321,60 @@ class ReplicationCommandHandler:
|
|||
)
|
||||
)
|
||||
|
||||
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
|
||||
def on_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: UserSyncCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_sync_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
await self._presence_handler.update_external_syncs_row(
|
||||
return self._presence_handler.update_external_syncs_row(
|
||||
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def on_CLEAR_USER_SYNC(
|
||||
def on_CLEAR_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
||||
):
|
||||
) -> Optional[Awaitable[None]]:
|
||||
if self._is_master:
|
||||
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def on_FEDERATION_ACK(
|
||||
self, conn: AbstractConnection, cmd: FederationAckCommand
|
||||
):
|
||||
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self._federation_sender:
|
||||
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(
|
||||
def on_REMOVE_PUSHER(
|
||||
self, conn: AbstractConnection, cmd: RemovePusherCommand
|
||||
):
|
||||
) -> Optional[Awaitable[None]]:
|
||||
remove_pusher_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
return self._handle_remove_pusher(cmd)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
|
||||
await self._store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
|
||||
)
|
||||
|
||||
self._notifier.on_new_replication_data()
|
||||
|
||||
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
|
||||
def on_USER_IP(
|
||||
self, conn: AbstractConnection, cmd: UserIpCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
return self._handle_user_ip(cmd)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _handle_user_ip(self, cmd: UserIpCommand):
|
||||
await self._store.insert_client_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
|
@ -365,10 +384,10 @@ class ReplicationCommandHandler:
|
|||
cmd.last_seen,
|
||||
)
|
||||
|
||||
if self._server_notices_sender:
|
||||
assert self._server_notices_sender is not None
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore RDATA that are just our own echoes
|
||||
return
|
||||
|
@ -382,7 +401,7 @@ class ReplicationCommandHandler:
|
|||
# 2. so we don't race with getting a POSITION command and fetching
|
||||
# missing RDATA.
|
||||
|
||||
await self._add_command_to_stream_queue(conn, cmd)
|
||||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_rdata(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
|
||||
|
@ -459,14 +478,14 @@ class ReplicationCommandHandler:
|
|||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore POSITION that are just our own echoes
|
||||
return
|
||||
|
||||
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
|
||||
|
||||
await self._add_command_to_stream_queue(conn, cmd)
|
||||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_position(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
|
||||
|
@ -526,9 +545,7 @@ class ReplicationCommandHandler:
|
|||
|
||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(
|
||||
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
|
||||
):
|
||||
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ import abc
|
|||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
|
||||
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
|
||||
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
|
||||
`ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
|
||||
if so, that will get run as a background process.
|
||||
|
||||
It also sends `PING` periodically, and correctly times out remote connections
|
||||
(if they send a `PING` command)
|
||||
|
@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
|
||||
# a logcontext which we use for processing incoming commands. We declare it as a
|
||||
# background process so that the CPU stats get reported to prometheus.
|
||||
self._logging_context = BackgroundProcessLoggingContext(
|
||||
"replication_command_handler-%s" % self.conn_id
|
||||
)
|
||||
ctx_name = "replication-conn-%s" % self.conn_id
|
||||
self._logging_context = BackgroundProcessLoggingContext(ctx_name)
|
||||
self._logging_context.request = ctx_name
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("[%s] Connection established", self.id())
|
||||
|
@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
|
||||
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
|
||||
|
||||
# Now lets try and call on_<CMD_NAME> function
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||
)
|
||||
self.handle_command(cmd)
|
||||
|
||||
async def handle_command(self, cmd: Command):
|
||||
def handle_command(self, cmd: Command) -> None:
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
First calls `self.on_<COMMAND>` if it exists, then calls
|
||||
`self.command_handler.on_<COMMAND>` if it exists. This allows for
|
||||
protocol level handling of commands (e.g. PINGs), before delegating to
|
||||
the handler.
|
||||
`self.command_handler.on_<COMMAND>` if it exists (which can optionally
|
||||
return an Awaitable).
|
||||
|
||||
This allows for protocol level handling of commands (e.g. PINGs), before
|
||||
delegating to the handler.
|
||||
|
||||
Args:
|
||||
cmd: received command
|
||||
|
@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(self, cmd)
|
||||
res = cmd_func(self, cmd)
|
||||
|
||||
# the handler might be a coroutine: fire it off as a background process
|
||||
# if so.
|
||||
|
||||
if isawaitable(res):
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), lambda: res
|
||||
)
|
||||
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
|
@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||
for cmd in pending:
|
||||
self.send_command(cmd)
|
||||
|
||||
async def on_PING(self, line):
|
||||
def on_PING(self, line):
|
||||
self.received_ping = True
|
||||
|
||||
async def on_ERROR(self, cmd):
|
||||
def on_ERROR(self, cmd):
|
||||
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
||||
|
||||
def pauseProducing(self):
|
||||
|
@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
self.send_command(ServerCommand(self.server_name))
|
||||
super().connectionMade()
|
||||
|
||||
async def on_NAME(self, cmd):
|
||||
def on_NAME(self, cmd):
|
||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||
self.name = cmd.data
|
||||
|
||||
|
@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
# Once we've connected subscribe to the necessary streams
|
||||
self.replicate()
|
||||
|
||||
async def on_SERVER(self, cmd):
|
||||
def on_SERVER(self, cmd):
|
||||
if cmd.data != self.server_name:
|
||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||
self.send_error("Wrong remote")
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import txredisapi
|
||||
|
@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
# remote instances.
|
||||
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
# Now lets try and call on_<CMD_NAME> function
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||
)
|
||||
self.handle_command(cmd)
|
||||
|
||||
async def handle_command(self, cmd: Command):
|
||||
def handle_command(self, cmd: Command) -> None:
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
By default delegates to on_<COMMAND>, which should return an awaitable.
|
||||
Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
|
||||
Awaitable).
|
||||
|
||||
Args:
|
||||
cmd: received command
|
||||
"""
|
||||
handled = False
|
||||
|
||||
# First call any command handlers on this instance. These are for redis
|
||||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(self, cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
if not cmd_func:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
return
|
||||
|
||||
res = cmd_func(self, cmd)
|
||||
|
||||
# the handler might be a coroutine: fire it off as a background process
|
||||
# if so.
|
||||
|
||||
if isawaitable(res):
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), lambda: res
|
||||
)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
logger.info("Lost connection to redis")
|
||||
|
|
Loading…
Reference in a new issue