Periodically send pings to detect dead Redis connections (#9218)

This is done by creating a custom `RedisFactory` subclass that
periodically pings all connections in its pool.

We also ensure that the `replyTimeout` param is non-null, so that we
timeout waiting for the reply to those pings (and thus triggering a
reconnect).
This commit is contained in:
Erik Johnston 2021-01-26 10:54:54 +00:00 committed by GitHub
parent 5b857b77f7
commit a1ff1e967f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 107 additions and 57 deletions

1
changelog.d/9218.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.

View file

@ -19,8 +19,9 @@ from typing import List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
class SubscriberProtocol:
class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
@ -39,14 +40,13 @@ def lazyConnection(
convertNumbers: bool = ...,
) -> RedisProtocol: ...
class SubscriberFactory:
def buildProtocol(self, addr): ...
class ConnectionHandler: ...
class RedisFactory:
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]
replyTimeout: Optional[int]
def __init__(
self,
uuid: str,
@ -59,3 +59,7 @@ class RedisFactory:
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
): ...
def buildProtocol(self, addr) -> RedisProtocol: ...
class SubscriberFactory(RedisFactory):
def __init__(self): ...

View file

@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
TypingStream,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
@ -300,7 +304,7 @@ class ReplicationCommandHandler:
# First create the connection for sending commands.
outbound_redis_connection = lazyConnection(
reactor=hs.get_reactor(),
hs=hs,
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,

View file

@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Type, cast
import txredisapi
@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.
Attributes:
handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to and publish from
(not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send
synapse_handler: The command handler to handle incoming commands.
synapse_stream_name: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
handler = None # type: ReplicationCommandHandler
stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol
synapse_handler = None # type: ReplicationCommandHandler
synapse_stream_name = None # type: str
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
await make_deferred_yieldable(self.subscribe(self.stream_name))
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
self.handler.new_connection(self)
self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# otherside won't know we've connected and so won't issue a REPLICATE.
self.handler.send_positions_to_connection(self)
self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command
"""
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self)
self.synapse_handler.lost_connection(self)
# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable(
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string
)
)
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
class SynapseRedisFactory(txredisapi.RedisFactory):
"""A subclass of RedisFactory that periodically sends pings to ensure that
we detect dead connections.
"""
def __init__(
self,
hs: "HomeServer",
uuid: str,
dbid: Optional[int],
poolsize: int,
isLazy: bool = False,
handler: Type = txredisapi.ConnectionHandler,
charset: str = "utf-8",
password: Optional[str] = None,
replyTimeout: int = 30,
convertNumbers: Optional[int] = True,
):
super().__init__(
uuid=uuid,
dbid=dbid,
poolsize=poolsize,
isLazy=isLazy,
handler=handler,
charset=charset,
password=password,
replyTimeout=replyTimeout,
convertNumbers=convertNumbers,
)
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
@wrap_as_background_process("redis_ping")
async def _send_ping(self):
for connection in self.pool:
try:
await make_deferred_yieldable(connection.ping())
except Exception:
logger.warning("Failed to send ping to a redis connection")
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
super().__init__()
super().__init__(
hs,
uuid="subscriber",
dbid=None,
poolsize=1,
replyTimeout=30,
password=hs.config.redis.redis_password,
)
# This sets the password on the RedisFactory base class (as
# SubscriberFactory constructor doesn't pass it through).
self.password = hs.config.redis.redis_password
self.synapse_handler = hs.get_tcp_replication()
self.synapse_stream_name = hs.hostname
self.handler = hs.get_tcp_replication()
self.stream_name = hs.hostname
self.outbound_redis_connection = outbound_redis_connection
self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
p = super().buildProtocol(addr) # type: RedisSubscriber
p = super().buildProtocol(addr)
p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubcriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
p.password = self.password
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_name = self.synapse_stream_name
return p
def lazyConnection(
reactor,
hs: "HomeServer",
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
charset: str = "utf-8",
password: Optional[str] = None,
connectTimeout: Optional[int] = None,
replyTimeout: Optional[int] = None,
convertNumbers: bool = True,
replyTimeout: int = 30,
) -> txredisapi.RedisProtocol:
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
reactor.
"""Creates a connection to Redis that is lazily set up and reconnects if the
connections is lost.
"""
isLazy = True
poolsize = 1
uuid = "%s:%d" % (host, port)
factory = txredisapi.RedisFactory(
uuid,
dbid,
poolsize,
isLazy,
txredisapi.ConnectionHandler,
charset,
password,
replyTimeout,
convertNumbers,
factory = SynapseRedisFactory(
hs,
uuid=uuid,
dbid=dbid,
poolsize=1,
isLazy=True,
handler=txredisapi.ConnectionHandler,
password=password,
replyTimeout=replyTimeout,
)
factory.continueTrying = reconnect
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout)
reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, 30)
return factory.handler