# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast

import attr
import txredisapi
from zope.interface import implementer

from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import IAddress, IConnector
from twisted.python.failure import Failure

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
    BackgroundProcessLoggingContext,
    run_as_background_process,
    wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
    Command,
    ReplicateCommand,
    parse_command_from_line,
)
from synapse.replication.tcp.context import ClientContextFactory
from synapse.replication.tcp.protocol import (
    IReplicationConnection,
    tcp_inbound_commands_counter,
    tcp_outbound_commands_counter,
)

if TYPE_CHECKING:
    from synapse.replication.tcp.handler import ReplicationCommandHandler
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

T = TypeVar("T")
V = TypeVar("V")


@attr.s
class ConstantProperty(Generic[T, V]):
    """A descriptor that returns the given constant, ignoring attempts to set
    it.
    """

    constant: V = attr.ib()

    def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
        return self.constant

    def __set__(self, obj: Optional[T], value: V) -> None:
        pass


@implementer(IReplicationConnection)
class RedisSubscriber(txredisapi.SubscriberProtocol):
    """Connection to redis subscribed to replication stream.

    This class fulfils two functions:

    (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd
        redis connection, parsing *incoming* messages into replication commands,
        and passing them to `ReplicationCommandHandler`

    (b) it implements the IReplicationConnection API, where it sends *outgoing*
        commands onto outbound_redis_connection.

    Due to the vagaries of `txredisapi` we don't want to have a custom
    constructor, so instead we expect the defined attributes below to be set
    immediately after initialisation.

    Attributes:
        synapse_handler: The command handler to handle incoming commands.
        synapse_stream_prefix: The *redis* stream name to subscribe to and publish
            from (not anything to do with Synapse replication streams).
        synapse_outbound_redis_connection: The connection to redis to use to send
            commands.
    """

    synapse_handler: "ReplicationCommandHandler"
    synapse_stream_prefix: str
    synapse_channel_names: List[str]
    synapse_outbound_redis_connection: txredisapi.ConnectionHandler

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

        # a logcontext which we use for processing incoming commands. We declare it
        # as a background process so that the CPU stats get reported to prometheus.
        with PreserveLoggingContext():
            # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to
            # capture the sentinel context as its containing context and won't prevent
            # GC of / unintentionally reactivate what would be the current context.
            self._logging_context = BackgroundProcessLoggingContext(
                "replication_command_handler"
            )

    def connectionMade(self) -> None:
        logger.info("Connected to redis")
        super().connectionMade()
        run_as_background_process("subscribe-replication", self._send_subscribe)

    async def _send_subscribe(self) -> None:
        # it's important to make sure that we only send the REPLICATE command once we
        # have successfully subscribed to the stream - otherwise we might miss the
        # POSITION response sent back by the other end.
        fully_qualified_stream_names = [
            f"{self.synapse_stream_prefix}/{stream_suffix}"
            for stream_suffix in self.synapse_channel_names
        ] + [self.synapse_stream_prefix]
        logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
        await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))

        logger.info(
            "Successfully subscribed to redis stream, sending REPLICATE command"
        )
        self.synapse_handler.new_connection(self)
        await self._async_send_command(ReplicateCommand())
        logger.info("REPLICATE successfully sent")

        # We send out our positions when there is a new connection in case the
        # other side missed updates. We do this for Redis connections as the
        # other side won't know we've connected and so won't issue a REPLICATE.
        self.synapse_handler.send_positions_to_connection(self)

    def messageReceived(self, pattern: str, channel: str, message: str) -> None:
        """Received a message from redis."""
        with PreserveLoggingContext(self._logging_context):
            self._parse_and_dispatch_message(message)

    def _parse_and_dispatch_message(self, message: str) -> None:
        if message.strip() == "":
            # Ignore blank lines
            return

        try:
            cmd = parse_command_from_line(message)
        except Exception:
            logger.exception(
                "Failed to parse replication line: %r",
                message,
            )
            return

        # We use "redis" as the name here as we don't have 1:1 connections to
        # remote instances.
        tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()

        self.handle_command(cmd)

    def handle_command(self, cmd: Command) -> None:
        """Handle a command we have received over the replication stream.

        Delegates to `self.synapse_handler.on_<COMMAND>` (which can optionally
        return an Awaitable).

        Args:
            cmd: received command
        """

        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
        if not cmd_func:
            logger.warning("Unhandled command: %r", cmd)
            return

        res = cmd_func(self, cmd)

        # the handler might be a coroutine: fire it off as a background process
        # if so.
        if isawaitable(res):
            run_as_background_process(
                "replication-" + cmd.get_logcontext_id(), lambda: res
            )

    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
        logger.info("Lost connection to redis")
        super().connectionLost(reason)
        self.synapse_handler.lost_connection(self)

        # mark the logging context as finished by triggering `__exit__()`
        with PreserveLoggingContext():
            with self._logging_context:
                pass

            # the sentinel context is now active, which may not be correct.
            # PreserveLoggingContext() will restore the correct logging context.

    def send_command(self, cmd: Command) -> None:
        """Send a command if connection has been established.

        Args:
            cmd: The command to send
        """
        run_as_background_process(
            "send-cmd", self._async_send_command, cmd, bg_start_span=False
        )

    async def _async_send_command(self, cmd: Command) -> None:
        """Encode a replication command and send it over our outbound connection"""
        string = "%s %s" % (cmd.NAME, cmd.to_line())
        if "\n" in string:
            raise Exception("Unexpected newline in command: %r" % (string,))

        encoded_string = string.encode("utf-8")

        # We use "redis" as the name here as we don't have 1:1 connections to
        # remote instances.
        tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()

        channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)

        await make_deferred_yieldable(
            self.synapse_outbound_redis_connection.publish(channel_name, encoded_string)
        )


class SynapseRedisFactory(txredisapi.RedisFactory):
    """A subclass of RedisFactory that periodically sends pings to ensure that
    we detect dead connections.
    """

    # We want to *always* retry connecting, txredisapi will stop if there is a
    # failure during certain operations, e.g. during AUTH.
    continueTrying = cast(bool, ConstantProperty(True))

    def __init__(
        self,
        hs: "HomeServer",
        uuid: str,
        dbid: Optional[int],
        poolsize: int,
        isLazy: bool = False,
        handler: Type = txredisapi.ConnectionHandler,
        charset: str = "utf-8",
        password: Optional[str] = None,
        replyTimeout: int = 30,
        convertNumbers: Optional[int] = True,
    ):
        super().__init__(
            uuid=uuid,
            dbid=dbid,
            poolsize=poolsize,
            isLazy=isLazy,
            handler=handler,
            charset=charset,
            password=password,
            replyTimeout=replyTimeout,
            convertNumbers=convertNumbers,
        )

        hs.get_clock().looping_call(self._send_ping, 30 * 1000)

    @wrap_as_background_process("redis_ping")
    async def _send_ping(self) -> None:
        for connection in self.pool:
            try:
                await make_deferred_yieldable(connection.ping())
            except Exception:
                logger.warning("Failed to send ping to a redis connection")

    # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
    # it's rubbish. We add our own here.

    def startedConnecting(self, connector: IConnector) -> None:
        logger.info(
            "Connecting to redis server %s", format_address(connector.getDestination())
        )
        super().startedConnecting(connector)

    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
        logger.info(
            "Connection to redis server %s failed: %s",
            format_address(connector.getDestination()),
            reason.value,
        )
        super().clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
        logger.info(
            "Connection to redis server %s lost: %s",
            format_address(connector.getDestination()),
            reason.value,
        )
        super().clientConnectionLost(connector, reason)


def format_address(address: IAddress) -> str:
    if isinstance(address, (IPv4Address, IPv6Address)):
        return "%s:%i" % (address.host, address.port)
    return str(address)


class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
    """This is a reconnecting factory that connects to redis and immediately
    subscribes to some streams.

    Args:
        hs
        outbound_redis_connection: A connection to redis that will be used to
            send outbound commands (this is separate to the redis connection
            used to subscribe).
        channel_names: A list of channel names to append to the base channel name
            to additionally subscribe to.
            e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
            example.com; example.com/ABC; and example.com/DEF.
""" maxDelay = 5 protocol = RedisSubscriber def __init__( self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler, channel_names: List[str], ): super().__init__( hs, uuid="subscriber", dbid=None, poolsize=1, replyTimeout=30, password=hs.config.redis.redis_password, ) self.synapse_handler = hs.get_replication_command_handler() self.synapse_stream_prefix = hs.hostname self.synapse_channel_names = channel_names self.synapse_outbound_redis_connection = outbound_redis_connection def buildProtocol(self, addr: IAddress) -> RedisSubscriber: p = super().buildProtocol(addr) p = cast(RedisSubscriber, p) # We do this here rather than add to the constructor of `RedisSubcriber` # as to do so would involve overriding `buildProtocol` entirely, however # the base method does some other things than just instantiating the # protocol. p.synapse_handler = self.synapse_handler p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection p.synapse_stream_prefix = self.synapse_stream_prefix p.synapse_channel_names = self.synapse_channel_names return p def lazyConnection( hs: "HomeServer", host: str = "localhost", port: int = 6379, dbid: Optional[int] = None, reconnect: bool = True, password: Optional[str] = None, replyTimeout: int = 30, ) -> txredisapi.ConnectionHandler: """Creates a connection to Redis that is lazily set up and reconnects if the connections is lost. """ uuid = "%s:%d" % (host, port) factory = SynapseRedisFactory( hs, uuid=uuid, dbid=dbid, poolsize=1, isLazy=True, handler=txredisapi.ConnectionHandler, password=password, replyTimeout=replyTimeout, ) factory.continueTrying = reconnect reactor = hs.get_reactor() if hs.config.redis.redis_use_tls: ssl_context_factory = ClientContextFactory(hs.config.redis) reactor.connectSSL( host, port, factory, ssl_context_factory, timeout=30, bindAddress=None, ) else: reactor.connectTCP( host, port, factory, timeout=30, bindAddress=None, ) return factory.handler