Handle replication commands synchronously where possible (#7876)

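Most of the handling we do for inbound replication commands can be done synchronously. There's no point spinning up a background process for every command when most handlers never need one, so a background process is now started only when a handler actually returns an awaitable.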
This commit is contained in:
Richard van der Hoff 2020-07-27 18:54:43 +01:00 committed by GitHub
parent 7c2e2c2077
commit f57b99af22
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 86 deletions

1
changelog.d/7876.bugfix Normal file
View file

@ -0,0 +1 @@
Fix an `AssertionError` exception introduced in v1.18.0rc1.

1
changelog.d/7876.misc Normal file
View file

@ -0,0 +1 @@
Further optimise queueing of inbound replication commands.

View file

@ -16,6 +16,7 @@
import logging import logging
from typing import ( from typing import (
Any, Any,
Awaitable,
Dict, Dict,
Iterable, Iterable,
Iterator, Iterator,
@ -33,6 +34,7 @@ from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import ( from synapse.replication.tcp.commands import (
ClearUserSyncsCommand, ClearUserSyncsCommand,
@ -152,7 +154,7 @@ class ReplicationCommandHandler:
# When POSITION or RDATA commands arrive, we stick them in a queue and process # When POSITION or RDATA commands arrive, we stick them in a queue and process
# them in order in a separate background process. # them in order in a separate background process.
# the streams which are currently being processed by _unsafe_process_stream # the streams which are currently being processed by _unsafe_process_queue
self._processing_streams = set() # type: Set[str] self._processing_streams = set() # type: Set[str]
# for each stream, a queue of commands that are awaiting processing, and the # for each stream, a queue of commands that are awaiting processing, and the
@ -185,7 +187,7 @@ class ReplicationCommandHandler:
if self._is_master: if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender() self._server_notices_sender = hs.get_server_notices_sender()
async def _add_command_to_stream_queue( def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None: ) -> None:
"""Queue the given received command for processing """Queue the given received command for processing
@ -199,33 +201,34 @@ class ReplicationCommandHandler:
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
return return
# if we're already processing this stream, stick the new command in the
# queue, and we're done.
if stream_name in self._processing_streams:
queue.append((cmd, conn)) queue.append((cmd, conn))
# if we're already processing this stream, there's nothing more to do:
# the new entry on the queue will get picked up in due course
if stream_name in self._processing_streams:
return return
# otherwise, process the new command. # fire off a background process to start processing the queue.
run_as_background_process(
"process-replication-data", self._unsafe_process_queue, stream_name
)
# arguably we should start off a new background process here, but nothing async def _unsafe_process_queue(self, stream_name: str):
# will be too upset if we don't return for ages, so let's save the overhead """Processes the command queue for the given stream, until it is empty
# and use the existing logcontext.
Does not check if there is already a thread processing the queue, hence "unsafe"
"""
assert stream_name not in self._processing_streams
self._processing_streams.add(stream_name) self._processing_streams.add(stream_name)
try: try:
# might as well skip the queue for this one, since it must be empty queue = self._command_queues_by_stream.get(stream_name)
assert not queue
await self._process_command(cmd, conn, stream_name)
# now process any other commands that have built up while we were
# dealing with that one.
while queue: while queue:
cmd, conn = queue.popleft() cmd, conn = queue.popleft()
try: try:
await self._process_command(cmd, conn, stream_name) await self._process_command(cmd, conn, stream_name)
except Exception: except Exception:
logger.exception("Failed to handle command %s", cmd) logger.exception("Failed to handle command %s", cmd)
finally: finally:
self._processing_streams.discard(stream_name) self._processing_streams.discard(stream_name)
@ -299,7 +302,7 @@ class ReplicationCommandHandler:
""" """
return self._streams_to_replicate return self._streams_to_replicate
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn) self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection): def send_positions_to_connection(self, conn: AbstractConnection):
@ -318,44 +321,60 @@ class ReplicationCommandHandler:
) )
) )
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): def on_USER_SYNC(
self, conn: AbstractConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc() user_sync_counter.inc()
if self._is_master: if self._is_master:
await self._presence_handler.update_external_syncs_row( return self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
) )
else:
return None
async def on_CLEAR_USER_SYNC( def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
): ) -> Optional[Awaitable[None]]:
if self._is_master: if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id) return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
async def on_FEDERATION_ACK( def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
self, conn: AbstractConnection, cmd: FederationAckCommand
):
federation_ack_counter.inc() federation_ack_counter.inc()
if self._federation_sender: if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token) self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
async def on_REMOVE_PUSHER( def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand self, conn: AbstractConnection, cmd: RemovePusherCommand
): ) -> Optional[Awaitable[None]]:
remove_pusher_counter.inc() remove_pusher_counter.inc()
if self._is_master: if self._is_master:
return self._handle_remove_pusher(cmd)
else:
return None
async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
await self._store.delete_pusher_by_app_id_pushkey_user_id( await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
) )
self._notifier.on_new_replication_data() self._notifier.on_new_replication_data()
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc() user_ip_cache_counter.inc()
if self._is_master: if self._is_master:
return self._handle_user_ip(cmd)
else:
return None
async def _handle_user_ip(self, cmd: UserIpCommand):
await self._store.insert_client_ip( await self._store.insert_client_ip(
cmd.user_id, cmd.user_id,
cmd.access_token, cmd.access_token,
@ -365,10 +384,10 @@ class ReplicationCommandHandler:
cmd.last_seen, cmd.last_seen,
) )
if self._server_notices_sender: assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id) await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name: if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes # Ignore RDATA that are just our own echoes
return return
@ -382,7 +401,7 @@ class ReplicationCommandHandler:
# 2. so we don't race with getting a POSITION command and fetching # 2. so we don't race with getting a POSITION command and fetching
# missing RDATA. # missing RDATA.
await self._add_command_to_stream_queue(conn, cmd) self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata( async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@ -459,14 +478,14 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows stream_name, instance_name, token, rows
) )
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name: if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes # Ignore POSITION that are just our own echoes
return return
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
await self._add_command_to_stream_queue(conn, cmd) self._add_command_to_stream_queue(conn, cmd)
async def _process_position( async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@ -526,9 +545,7 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name) self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP( def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
):
""""Called when get a new REMOTE_SERVER_UP command.""" """"Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data) self._replication_data_handler.on_remote_server_up(cmd.data)

View file

@ -50,6 +50,7 @@ import abc
import fcntl import fcntl
import logging import logging
import struct import struct
from inspect import isawaitable
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
from prometheus_client import Counter from prometheus_client import Counter
@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`. command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
`ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
if so, that will get run as a background process.
It also sends `PING` periodically, and correctly times out remote connections It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command) (if they send a `PING` command)
@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a # a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus. # background process so that the CPU stats get reported to prometheus.
self._logging_context = BackgroundProcessLoggingContext( ctx_name = "replication-conn-%s" % self.conn_id
"replication_command_handler-%s" % self.conn_id self._logging_context = BackgroundProcessLoggingContext(ctx_name)
) self._logging_context.request = ctx_name
def connectionMade(self): def connectionMade(self):
logger.info("[%s] Connection established", self.id()) logger.info("[%s] Connection established", self.id())
@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
# Now lets try and call on_<CMD_NAME> function self.handle_command(cmd)
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
async def handle_command(self, cmd: Command): def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream. """Handle a command we have received over the replication stream.
First calls `self.on_<COMMAND>` if it exists, then calls First calls `self.on_<COMMAND>` if it exists, then calls
`self.command_handler.on_<COMMAND>` if it exists. This allows for `self.command_handler.on_<COMMAND>` if it exists (which can optionally
protocol level handling of commands (e.g. PINGs), before delegating to return an Awaitable).
the handler.
This allows for protocol level handling of commands (e.g. PINGs), before
delegating to the handler.
Args: Args:
cmd: received command cmd: received command
@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling. # specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func: if cmd_func:
await cmd_func(cmd) cmd_func(cmd)
handled = True handled = True
# Then call out to the handler. # Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func: if cmd_func:
await cmd_func(self, cmd) res = cmd_func(self, cmd)
# the handler might be a coroutine: fire it off as a background process
# if so.
if isawaitable(res):
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), lambda: res
)
handled = True handled = True
if not handled: if not handled:
@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending: for cmd in pending:
self.send_command(cmd) self.send_command(cmd)
async def on_PING(self, line): def on_PING(self, line):
self.received_ping = True self.received_ping = True
async def on_ERROR(self, cmd): def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self): def pauseProducing(self):
@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name)) self.send_command(ServerCommand(self.server_name))
super().connectionMade() super().connectionMade()
async def on_NAME(self, cmd): def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data) logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data self.name = cmd.data
@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams # Once we've connected subscribe to the necessary streams
self.replicate() self.replicate()
async def on_SERVER(self, cmd): def on_SERVER(self, cmd):
if cmd.data != self.server_name: if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote") self.send_error("Wrong remote")

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from inspect import isawaitable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import txredisapi import txredisapi
@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances. # remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
# Now lets try and call on_<CMD_NAME> function self.handle_command(cmd)
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
async def handle_command(self, cmd: Command): def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream. """Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable. Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
Awaitable).
Args: Args:
cmd: received command cmd: received command
""" """
handled = False
# First call any command handlers on this instance. These are for redis
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func: if not cmd_func:
await cmd_func(self, cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd) logger.warning("Unhandled command: %r", cmd)
return
res = cmd_func(self, cmd)
# the handler might be a coroutine: fire it off as a background process
# if so.
if isawaitable(res):
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), lambda: res
)
def connectionLost(self, reason): def connectionLost(self, reason):
logger.info("Lost connection to redis") logger.info("Lost connection to redis")