0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-16 23:43:52 +01:00

Move server command handling out of TCP protocol (#7187)

This completes the merging of server and client command processing.
This commit is contained in:
Erik Johnston 2020-04-07 10:51:07 +01:00 committed by GitHub
parent 71953139d1
commit 82498ee901
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 238 additions and 270 deletions

1
changelog.d/7187.misc Normal file
View file

@ -0,0 +1 @@
Move server command handling out of TCP protocol.

View file

@ -19,8 +19,10 @@ from typing import Any, Callable, Dict, List, Optional, Set
from prometheus_client import Counter from prometheus_client import Counter
from synapse.metrics import LaterGauge
from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import ( from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command, Command,
FederationAckCommand, FederationAckCommand,
InvalidateCacheCommand, InvalidateCacheCommand,
@ -28,10 +30,12 @@ from synapse.replication.tcp.commands import (
RdataCommand, RdataCommand,
RemoteServerUpCommand, RemoteServerUpCommand,
RemovePusherCommand, RemovePusherCommand,
ReplicateCommand,
SyncCommand, SyncCommand,
UserIpCommand, UserIpCommand,
UserSyncCommand, UserSyncCommand,
) )
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -42,6 +46,13 @@ logger = logging.getLogger(__name__)
inbound_rdata_count = Counter( inbound_rdata_count = Counter(
"synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
) )
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
class ReplicationCommandHandler: class ReplicationCommandHandler:
@ -52,6 +63,10 @@ class ReplicationCommandHandler:
def __init__(self, hs): def __init__(self, hs):
self._replication_data_handler = hs.get_replication_data_handler() self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
self._notifier = hs.get_notifier()
self._clock = hs.get_clock()
self._instance_id = hs.get_instance_id()
# Set of streams that we've caught up with. # Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str] self._streams_connected = set() # type: Set[str]
@ -69,8 +84,26 @@ class ReplicationCommandHandler:
# The factory used to create connections. # The factory used to create connections.
self._factory = None # type: Optional[ReplicationClientFactory] self._factory = None # type: Optional[ReplicationClientFactory]
# The current connection. None if we are currently (re)connecting # The currently connected connections.
self._connection = None self._connections = [] # type: List[AbstractConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self._connections),
)
self._is_master = hs.config.worker_app is None
self._federation_sender = None
if self._is_master and not hs.config.send_federation:
self._federation_sender = hs.get_federation_sender()
self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
def start_replication(self, hs): def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server """Helper method to start a replication connection to the remote server
@ -82,6 +115,70 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory) hs.get_reactor().connectTCP(host, port, self._factory)
async def on_REPLICATE(self, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
return
for stream_name, stream in self._streams.items():
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
async def on_USER_SYNC(self, cmd: UserSyncCommand):
user_sync_counter.inc()
if self._is_master:
await self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
remove_pusher_counter.inc()
if self._is_master:
await self._store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
invalidate_cache_counter.inc()
if self._is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self._store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self._is_master:
await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, cmd: RdataCommand): async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc() inbound_rdata_count.labels(stream_name).inc()
@ -174,6 +271,9 @@ class ReplicationCommandHandler:
""""Called when get a new REMOTE_SERVER_UP command.""" """"Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data) self._replication_data_handler.on_remote_server_up(cmd.data)
if self._is_master:
self._notifier.notify_remote_server_up(cmd.data)
def get_currently_syncing_users(self): def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called """Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the when a connection has been established and we need to send the
@ -181,29 +281,63 @@ class ReplicationCommandHandler:
""" """
return self._presence_handler.get_currently_syncing_users() return self._presence_handler.get_currently_syncing_users()
def update_connection(self, connection): def new_connection(self, connection: AbstractConnection):
"""Called when a connection has been established (or lost with None). """Called when we have a new connection.
""" """
self._connection = connection self._connections.append(connection)
def finished_connecting(self): # If we are connected to replication as a client (rather than a server)
"""Called when we have successfully subscribed and caught up to all # we need to reset the reconnection delay on the client factory (which
streams we're interested in. # is used to do exponential back off when the connection drops).
""" #
logger.info("Finished connecting to server") # Ideally we would reset the delay when we've "fully established" the
# connection (for some definition thereof) to stop us from tightlooping
# We don't reset the delay any earlier as otherwise if there is a # on reconnection if something fails after this point and we drop the
# problem during start up we'll end up tight looping connecting to the # connection. Unfortunately, we don't really have a better definition of
# server. # "fully established" than the connection being established.
if self._factory: if self._factory:
self._factory.resetDelay() self._factory.resetDelay()
def send_command(self, cmd: Command): # Tell the server if we have any users currently syncing (should only
"""Send a command to master (when we get establish a connection if we # happen on synchrotrons)
don't have one already.) currently_syncing = self.get_currently_syncing_users()
now = self._clock.time_msec()
for user_id in currently_syncing:
connection.send_command(
UserSyncCommand(self._instance_id, user_id, True, now)
)
def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost.
""" """
if self._connection: try:
self._connection.send_command(cmd) self._connections.remove(connection)
except ValueError:
pass
def connected(self) -> bool:
"""Do we have any replication connections open?
Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
"""
return bool(self._connections)
def send_command(self, cmd: Command):
"""Send a command to all connected connections.
"""
if self._connections:
for connection in self._connections:
try:
connection.send_command(cmd)
except Exception:
# We probably want to catch some types of exceptions here
# and log them as warnings (e.g. connection gone), but I
# can't find what those exception types they would be.
logger.exception(
"Failed to write command %s to connection %s",
cmd.NAME,
connection,
)
else: else:
logger.warning("Dropping command as not connected: %r", cmd.NAME) logger.warning("Dropping command as not connected: %r", cmd.NAME)
@ -250,3 +384,10 @@ class ReplicationCommandHandler:
def send_remote_server_up(self, server: str): def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server)) self.send_command(RemoteServerUpCommand(server))
def stream_update(self, stream_name: str, token: str, data: Any):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))

View file

@ -46,6 +46,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping > ERROR server stopping
* connection closed by server * * connection closed by server *
""" """
import abc
import fcntl import fcntl
import logging import logging
import struct import struct
@ -69,13 +70,8 @@ from synapse.replication.tcp.commands import (
ErrorCommand, ErrorCommand,
NameCommand, NameCommand,
PingCommand, PingCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
ReplicateCommand, ReplicateCommand,
ServerCommand, ServerCommand,
SyncCommand,
UserSyncCommand,
) )
from synapse.types import Collection from synapse.types import Collection
from synapse.util import Clock from synapse.util import Clock
@ -118,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
are only sent by the server. are only sent by the server.
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command. command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
It also sends `PING` periodically, and correctly times out remote connections It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command) (if they send a `PING` command)
@ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
max_line_buffer = 10000 max_line_buffer = 10000
def __init__(self, clock): def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
self.clock = clock self.clock = clock
self.command_handler = handler
self.last_received_command = self.clock.time_msec() self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0 self.last_sent_command = 0
@ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out. # can time us out.
self.send_command(PingCommand(self.clock.time_msec())) self.send_command(PingCommand(self.clock.time_msec()))
self.command_handler.new_connection(self)
def send_ping(self): def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection """Periodically sends a ping and checks if we should close the connection
due to the other side timing out. due to the other side timing out.
@ -243,13 +242,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
async def handle_command(self, cmd: Command): async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream. """Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable. First calls `self.on_<COMMAND>` if it exists, then calls
`self.command_handler.on_<COMMAND>` if it exists. This allows for
protocol level handling of commands (e.g. PINGs), before delegating to
the handler.
Args: Args:
cmd: received command cmd: received command
""" """
handler = getattr(self, "on_%s" % (cmd.NAME,)) handled = False
await handler(cmd)
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
def close(self): def close(self):
logger.warning("[%s] Closing connection", self.id()) logger.warning("[%s] Closing connection", self.id())
@ -378,6 +395,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED self.state = ConnectionStates.CLOSED
self.pending_commands = [] self.pending_commands = []
self.command_handler.lost_connection(self)
if self.transport: if self.transport:
self.transport.unregisterProducer() self.transport.unregisterProducer()
@ -404,74 +423,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
def __init__(self, server_name, clock, streamer): def __init__(
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
):
super().__init__(clock, handler)
self.server_name = server_name self.server_name = server_name
self.streamer = streamer
def connectionMade(self): def connectionMade(self):
self.send_command(ServerCommand(self.server_name)) self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self) super().connectionMade()
self.streamer.new_connection(self)
async def on_NAME(self, cmd): async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data) logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data self.name = cmd.data
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(self, cmd):
await self.streamer.on_clear_user_syncs(cmd.instance_id)
async def on_REPLICATE(self, cmd):
# Subscribe to all streams we're publishing to.
for stream_name in self.streamer.streams_by_name:
current_token = self.streamer.get_stream_token(stream_name)
self.send_command(PositionCommand(stream_name, current_token))
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_INVALIDATE_CACHE(self, cmd):
await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.streamer.on_remote_server_up(cmd.data)
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, token, data))
def send_sync(self, data):
self.send_command(SyncCommand(data))
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@ -485,59 +451,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
clock: Clock, clock: Clock,
command_handler: "ReplicationCommandHandler", command_handler: "ReplicationCommandHandler",
): ):
BaseReplicationStreamProtocol.__init__(self, clock) super().__init__(clock, command_handler)
self.instance_id = hs.get_instance_id()
self.client_name = client_name self.client_name = client_name
self.server_name = server_name self.server_name = server_name
self.handler = command_handler
def connectionMade(self): def connectionMade(self):
self.send_command(NameCommand(self.client_name)) self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self) super().connectionMade()
# Once we've connected subscribe to the necessary streams # Once we've connected subscribe to the necessary streams
self.replicate() self.replicate()
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.handler.get_currently_syncing_users()
now = self.clock.time_msec()
for user_id in currently_syncing:
self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
self.handler.finished_connecting()
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
Delegates to `command_handler.on_<COMMAND>`, which must return an
awaitable.
Args:
cmd: received command
"""
handled = False
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
async def on_SERVER(self, cmd): async def on_SERVER(self, cmd):
if cmd.data != self.server_name: if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@ -550,9 +475,21 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand()) self.send_command(ReplicateCommand())
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self) class AbstractConnection(abc.ABC):
self.handler.update_connection(None) """An interface for replication connections.
"""
@abc.abstractmethod
def send_command(self, cmd: Command):
"""Send the command down the connection
"""
pass
# This tells python that `BaseReplicationStreamProtocol` implements the
# interface.
AbstractConnection.register(BaseReplicationStreamProtocol)
# The following simply registers metrics for the replication connections # The following simply registers metrics for the replication connections

View file

@ -17,7 +17,7 @@
import logging import logging
import random import random
from typing import Any, Dict, List from typing import Dict
from six import itervalues from six import itervalues
@ -25,24 +25,14 @@ from prometheus_client import Counter
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from .protocol import ServerReplicationStreamProtocol from synapse.util.metrics import Measure
from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream
stream_updates_counter = Counter( stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
) )
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,13 +42,23 @@ class ReplicationStreamProtocolFactory(Factory):
""" """
def __init__(self, hs): def __init__(self, hs):
self.streamer = hs.get_replication_streamer() self.command_handler = hs.get_tcp_replication()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
# If we've created a `ReplicationStreamProtocolFactory` then we're
# almost certainly registering a replication listener, so let's ensure
# that we've started a `ReplicationStreamer` instance to actually push
# data.
#
# (This is a bit of a weird place to do this, but the alternatives such
# as putting this in `HomeServer.setup()`, requires either passing the
# listener config again or always starting a `ReplicationStreamer`.)
hs.get_replication_streamer()
def buildProtocol(self, addr): def buildProtocol(self, addr):
return ServerReplicationStreamProtocol( return ServerReplicationStreamProtocol(
self.server_name, self.clock, self.streamer self.server_name, self.clock, self.command_handler
) )
@ -78,16 +78,6 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
# List of streams that clients can subscribe to. # List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been # We only support federation stream if federation sending hase been
# disabled on the master. # disabled on the master.
@ -104,18 +94,12 @@ class ReplicationStreamer(object):
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke) self.notifier.add_replication_callback(self.on_notifier_poke)
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates # Keeps track of whether we are currently checking for updates
self.is_looping = False self.is_looping = False
self.pending_updates = False self.pending_updates = False
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) self.command_handler = hs.get_tcp_replication()
def on_shutdown(self):
# close all connections on shutdown
for conn in self.connections:
conn.send_error("server shutting down")
def get_streams(self) -> Dict[str, Stream]: def get_streams(self) -> Dict[str, Stream]:
"""Get a mapp from stream name to stream instance. """Get a mapp from stream name to stream instance.
@ -129,7 +113,7 @@ class ReplicationStreamer(object):
This should get called each time new data is available, even if it This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed is currently being executed, so that nothing gets missed
""" """
if not self.connections: if not self.command_handler.connected():
# Don't bother if nothing is listening. We still need to advance # Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall beihind forever # the stream tokens otherwise they'll fall beihind forever
for stream in self.streams: for stream in self.streams:
@ -186,9 +170,7 @@ class ReplicationStreamer(object):
raise raise
logger.debug( logger.debug(
"Sending %d updates to %d connections", "Sending %d updates", len(updates),
len(updates),
len(self.connections),
) )
if updates: if updates:
@ -204,112 +186,19 @@ class ReplicationStreamer(object):
# token. See RdataCommand for more details. # token. See RdataCommand for more details.
batched_updates = _batch_updates(updates) batched_updates = _batch_updates(updates)
for conn in self.connections: for token, row in batched_updates:
for token, row in batched_updates: try:
try: self.command_handler.stream_update(
conn.stream_update(stream.NAME, token, row) stream.NAME, token, row
except Exception: )
logger.exception("Failed to replicate") except Exception:
logger.exception("Failed to replicate")
logger.debug("No more pending updates, breaking poke loop") logger.debug("No more pending updates, breaking poke loop")
finally: finally:
self.pending_updates = False self.pending_updates = False
self.is_looping = False self.is_looping = False
def get_stream_token(self, stream_name):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
stream = self.streams_by_name.get(stream_name, None)
if not stream:
raise Exception("unknown stream %s", stream_name)
return stream.current_token()
@measure_func("repl.federation_ack")
def federation_ack(self, token):
"""We've received an ack for federation stream from a client.
"""
federation_ack_counter.inc()
if self.federation_sender:
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
await self.presence_handler.update_external_syncs_row(
instance_id, user_id, is_syncing, last_sync_ms
)
async def on_clear_user_syncs(self, instance_id):
"""A replication client wants us to drop all their UserSync data.
"""
await self.presence_handler.update_external_syncs_clear(instance_id)
@measure_func("repl.on_remove_pusher")
async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
# We invalidate the cache locally, but then also stream that to other
# workers.
await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
async def on_user_ip(
self, user_id, access_token, ip, user_agent, device_id, last_seen
):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
await self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen
)
await self._server_notices_sender.on_user_ip(user_id)
@measure_func("repl.on_remote_server_up")
def on_remote_server_up(self, server: str):
self.notifier.notify_remote_server_up(server)
def send_remote_server_up(self, server: str):
for conn in self.connections:
conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
Used in tests.
"""
for conn in self.connections:
conn.send_sync(data)
def new_connection(self, connection):
"""A new client connection has been established
"""
self.connections.append(connection)
def lost_connection(self, connection):
"""A client connection has been lost
"""
try:
self.connections.remove(connection)
except ValueError:
pass
def _batch_updates(updates): def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to """Takes a list of updates of form [(token, row)] and sets the token to