Add missing type hints to synapse.replication. (#11938)

This commit is contained in:
Patrick Cloke 2022-02-08 11:03:08 -05:00 committed by GitHub
parent 8c94b3abe9
commit d0e78af35e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 209 additions and 147 deletions

1
changelog.d/11938.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to replication code.

View file

@ -169,6 +169,9 @@ disallow_untyped_defs = True
[mypy-synapse.push.*]
disallow_untyped_defs = True
[mypy-synapse.replication.*]
disallow_untyped_defs = True
[mypy-synapse.rest.*]
disallow_untyped_defs = True

View file

@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker):
for table, column in extra_tables:
self.advance(None, _load_current_id(db_conn, table, column))
def advance(self, instance_name: Optional[str], new_id: int):
def advance(self, instance_name: Optional[str], new_id: int) -> None:
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self) -> int:

View file

@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore):
cache_name="client_ip_last_seen", max_size=50000
)
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
async def insert_client_ip(
self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
) -> None:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_caches_for_devices(self, token, rows):
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
self._group_updates_id_gen.get_current_token(),
)
def get_group_stream_token(self):
def get_group_stream_token(self) -> int:
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(instance_name, token)
for row in rows:

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@ -20,10 +21,12 @@ from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(
self, stream_name: str, instance_name: str, token, rows
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) # type: ignore
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@ -14,10 +14,12 @@
"""A replication client for use by synapse workers.
"""
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes
from synapse.federation import send_queue
@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector):
def startedConnecting(self, connector: IConnector) -> None:
logger.info("Connecting to replication: %r", connector.getDestination())
def buildProtocol(self, addr):
def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.hs,
@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
self.command_handler,
)
def clientConnectionLost(self, connector, reason):
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.error("Lost replication conn: %r", reason)
ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
def clientConnectionFailed(self, connector, reason):
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.error("Failed to connect to replication: %r", reason)
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
@ -131,7 +133,7 @@ class ReplicationDataHandler:
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
@ -252,14 +254,16 @@ class ReplicationDataHandler:
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
async def on_position(self, stream_name: str, instance_name: str, token: int):
async def on_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
await self.on_rdata(stream_name, instance_name, token, [])
# We poke the generic "replication" notifier to wake anything up that
# may be streaming.
self.notifier.notify_replication()
def on_remote_server_up(self, server: str):
def on_remote_server_up(self, server: str) -> None:
"""Called when we get a new REMOTE_SERVER_UP command."""
# Let's wake up the transaction queue for the server in case we have
@ -269,7 +273,7 @@ class ReplicationDataHandler:
async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
) -> None:
"""Wait until this instance has received updates up to and including
the given stream position.
"""
@ -304,7 +308,7 @@ class ReplicationDataHandler:
"Finished waiting for repl stream %r to reach %s", stream_name, position
)
def stop_pusher(self, user_id, app_id, pushkey):
def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
if not self._notify_pushers:
return
@ -316,13 +320,13 @@ class ReplicationDataHandler:
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()
async def start_pusher(self, user_id, app_id, pushkey):
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
if not self._notify_pushers:
return
key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
class FederationSenderHandler:
@ -353,10 +357,12 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
def wake_destination(self, server: str):
def wake_destination(self, server: str) -> None:
self.federation_sender.wake_destination(server)
async def process_replication_rows(self, stream_name, token, rows):
async def process_replication_rows(
self, stream_name: str, token: int, rows: list
) -> None:
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
if stream_name == "federation":
@ -384,11 +390,12 @@ class FederationSenderHandler:
for host in hosts:
self.federation_sender.send_device_messages(host)
async def _on_new_receipts(self, rows):
async def _on_new_receipts(
self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
) -> None:
"""
Args:
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
new receipts to be processed
rows: new receipts to be processed
"""
for receipt in rows:
# we only want to send on receipts for our own users
@ -408,7 +415,7 @@ class FederationSenderHandler:
)
await self.federation_sender.send_read_receipt(receipt_info)
async def update_token(self, token):
async def update_token(self, token: int) -> None:
"""Update the record of where we have processed to in the federation stream.
Called after we have processed an update received over replication. Sends
@ -428,7 +435,7 @@ class FederationSenderHandler:
run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
async def _save_and_send_ack(self):
async def _save_and_send_ack(self) -> None:
"""Save the current federation position in the database and send an ACK
to master with where we're up to.
"""

View file

@ -18,12 +18,15 @@ allowed to be sent by which side.
"""
import abc
import logging
from typing import Tuple, Type
from typing import Optional, Tuple, Type, TypeVar
from synapse.replication.tcp.streams._base import StreamRow
from synapse.util import json_decoder, json_encoder
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="Command")
class Command(metaclass=abc.ABCMeta):
"""The base command class.
@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
def from_line(cls, line):
def from_line(cls: Type[T], line: str) -> T:
"""Deserialises a line from the wire into this command. `line` does not
include the command.
"""
@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
prefix.
"""
def get_logcontext_id(self):
def get_logcontext_id(self) -> str:
"""Get a suitable string for the logcontext when processing this command"""
# by default, we just use the command name.
return self.NAME
SC = TypeVar("SC", bound="_SimpleCommand")
class _SimpleCommand(Command):
"""An implementation of Command whose argument is just a 'data' string."""
def __init__(self, data):
def __init__(self, data: str):
self.data = data
@classmethod
def from_line(cls, line):
def from_line(cls: Type[SC], line: str) -> SC:
return cls(line)
def to_line(self) -> str:
@ -109,14 +115,16 @@ class RdataCommand(Command):
NAME = "RDATA"
def __init__(self, stream_name, instance_name, token, row):
def __init__(
self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
def from_line(cls, line):
def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
stream_name,
@ -125,7 +133,7 @@ class RdataCommand(Command):
json_decoder.decode(row_json),
)
def to_line(self):
def to_line(self) -> str:
return " ".join(
(
self.stream_name,
@ -135,7 +143,7 @@ class RdataCommand(Command):
)
)
def get_logcontext_id(self):
def get_logcontext_id(self) -> str:
return "RDATA-" + self.stream_name
@ -164,18 +172,20 @@ class PositionCommand(Command):
NAME = "POSITION"
def __init__(self, stream_name, instance_name, prev_token, new_token):
def __init__(
self, stream_name: str, instance_name: str, prev_token: int, new_token: int
):
self.stream_name = stream_name
self.instance_name = instance_name
self.prev_token = prev_token
self.new_token = new_token
@classmethod
def from_line(cls, line):
def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
return cls(stream_name, instance_name, int(prev_token), int(new_token))
def to_line(self):
def to_line(self) -> str:
return " ".join(
(
self.stream_name,
@ -218,14 +228,14 @@ class ReplicateCommand(Command):
NAME = "REPLICATE"
def __init__(self):
def __init__(self) -> None:
pass
@classmethod
def from_line(cls, line):
def from_line(cls: Type[T], line: str) -> T:
return cls()
def to_line(self):
def to_line(self) -> str:
return ""
@ -247,14 +257,16 @@ class UserSyncCommand(Command):
NAME = "USER_SYNC"
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
def __init__(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
@ -262,7 +274,7 @@ class UserSyncCommand(Command):
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
def to_line(self) -> str:
return " ".join(
(
self.instance_id,
@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):
NAME = "CLEAR_USER_SYNC"
def __init__(self, instance_id):
def __init__(self, instance_id: str):
self.instance_id = instance_id
@classmethod
def from_line(cls, line):
def from_line(
cls: Type["ClearUserSyncsCommand"], line: str
) -> "ClearUserSyncsCommand":
return cls(line)
def to_line(self):
def to_line(self) -> str:
return self.instance_id
@ -316,7 +330,9 @@ class FederationAckCommand(Command):
self.token = token
@classmethod
def from_line(cls, line: str) -> "FederationAckCommand":
def from_line(
cls: Type["FederationAckCommand"], line: str
) -> "FederationAckCommand":
instance_name, token = line.split(" ")
return cls(instance_name, int(token))
@ -334,7 +350,15 @@ class UserIpCommand(Command):
NAME = "USER_IP"
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
def __init__(
self,
user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: str,
last_seen: int,
):
self.user_id = user_id
self.access_token = access_token
self.ip = ip
@ -343,14 +367,14 @@ class UserIpCommand(Command):
self.last_seen = last_seen
@classmethod
def from_line(cls, line):
def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
user_id, jsn = line.split(" ", 1)
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
def to_line(self):
def to_line(self) -> str:
return (
self.user_id
+ " "

View file

@ -261,7 +261,7 @@ class ReplicationCommandHandler:
"process-replication-data", self._unsafe_process_queue, stream_name
)
async def _unsafe_process_queue(self, stream_name: str):
async def _unsafe_process_queue(self, stream_name: str) -> None:
"""Processes the command queue for the given stream, until it is empty
Does not check if there is already a thread processing the queue, hence "unsafe"
@ -294,7 +294,7 @@ class ReplicationCommandHandler:
# This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
def start_replication(self, hs: "HomeServer"):
def start_replication(self, hs: "HomeServer") -> None:
"""Helper method to start a replication connection to the remote server
using TCP.
"""
@ -345,10 +345,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instance replicates."""
return self._streams_to_replicate
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: IReplicationConnection):
def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
"""Send current position of all streams this process is source of to
the connection.
"""
@ -392,7 +392,7 @@ class ReplicationCommandHandler:
def on_FEDERATION_ACK(
self, conn: IReplicationConnection, cmd: FederationAckCommand
):
) -> None:
federation_ack_counter.inc()
if self._federation_sender:
@ -408,7 +408,7 @@ class ReplicationCommandHandler:
else:
return None
async def _handle_user_ip(self, cmd: UserIpCommand):
async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
await self._store.insert_client_ip(
cmd.user_id,
cmd.access_token,
@ -421,7 +421,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@ -497,7 +497,7 @@ class ReplicationCommandHandler:
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
"""Called to handle a batch of replication data with a given stream token.
Args:
@ -512,7 +512,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None:
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@ -581,7 +581,7 @@ class ReplicationCommandHandler:
def on_REMOTE_SERVER_UP(
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
):
) -> None:
"""Called when we get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@ -604,7 +604,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: IReplicationConnection):
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
@ -631,7 +631,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
def lost_connection(self, connection: IReplicationConnection):
def lost_connection(self, connection: IReplicationConnection) -> None:
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
@ -653,7 +653,7 @@ class ReplicationCommandHandler:
def send_command(
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
) -> None:
"""Send a command to all connected connections.
Args:
@ -680,7 +680,7 @@ class ReplicationCommandHandler:
else:
logger.warning("Dropping command as not connected: %r", cmd.NAME)
def send_federation_ack(self, token: int):
def send_federation_ack(self, token: int) -> None:
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
@ -688,7 +688,7 @@ class ReplicationCommandHandler:
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
) -> None:
"""Poke the master that a user has started/stopped syncing."""
self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
@ -702,15 +702,15 @@ class ReplicationCommandHandler:
user_agent: str,
device_id: str,
last_seen: int,
):
) -> None:
"""Tell the master that the user made a request."""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
def send_remote_server_up(self, server: str) -> None:
self.send_command(RemoteServerUpCommand(server))
def stream_update(self, stream_name: str, token: str, data: Any):
def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not

View file

@ -49,7 +49,7 @@ import fcntl
import logging
import struct
from inspect import isawaitable
from typing import TYPE_CHECKING, Collection, List, Optional
from typing import TYPE_CHECKING, Any, Collection, List, Optional
from prometheus_client import Counter
from zope.interface import Interface, implementer
@ -123,7 +123,7 @@ class ConnectionStates:
class IReplicationConnection(Interface):
"""An interface for replication connections."""
def send_command(cmd: Command):
def send_command(cmd: Command) -> None:
"""Send the command down the connection"""
@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
"replication-conn", self.conn_id
)
def connectionMade(self):
def connectionMade(self) -> None:
logger.info("[%s] Connection established", self.id())
self.state = ConnectionStates.ESTABLISHED
@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Always send the initial PING so that the other side knows that they
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))
self.send_command(PingCommand(str(self.clock.time_msec())))
self.command_handler.new_connection(self)
def send_ping(self):
def send_ping(self) -> None:
"""Periodically sends a ping and checks if we should close the connection
due to the other side timing out.
"""
@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
self.send_command(PingCommand(now))
self.send_command(PingCommand(str(now)))
if (
self.received_ping
@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
)
self.send_error("ping timeout")
def lineReceived(self, line: bytes):
def lineReceived(self, line: bytes) -> None:
"""Called when we've received a line"""
with PreserveLoggingContext(self._logging_context):
self._parse_and_dispatch_line(line)
def _parse_and_dispatch_line(self, line: bytes):
def _parse_and_dispatch_line(self, line: bytes) -> None:
if line.strip() == "":
# Ignore blank lines
return
@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if not handled:
logger.warning("Unhandled command: %r", cmd)
def close(self):
def close(self) -> None:
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()
def send_error(self, error_string, *args):
def send_error(self, error_string: str, *args: Any) -> None:
"""Send an error to remote and close the connection."""
self.send_command(ErrorCommand(error_string % args))
self.close()
def send_command(self, cmd, do_buffer=True):
def send_command(self, cmd: Command, do_buffer: bool = True) -> None:
"""Send a command if connection has been established.
Args:
cmd (Command)
do_buffer (bool): Whether to buffer the message or always attempt
cmd
do_buffer: Whether to buffer the message or always attempt
to send the command. This is mostly used to send an error
message if we're about to close the connection due to our buffers
becoming full.
@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_sent_command = self.clock.time_msec()
def _queue_command(self, cmd):
def _queue_command(self, cmd: Command) -> None:
"""Queue the command until the connection is ready to write to again."""
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
self.close()
def _send_pending_commands(self):
def _send_pending_commands(self) -> None:
"""Send any queued commands"""
pending = self.pending_commands
self.pending_commands = []
for cmd in pending:
self.send_command(cmd)
def on_PING(self, line):
def on_PING(self, cmd: PingCommand) -> None:
self.received_ping = True
def on_ERROR(self, cmd):
def on_ERROR(self, cmd: ErrorCommand) -> None:
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
def pauseProducing(self) -> None:
"""This is called when both the kernel send buffer and the twisted
tcp connection send buffers have become full.
@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
logger.info("[%s] Pause producing", self.id())
self.state = ConnectionStates.PAUSED
def resumeProducing(self):
def resumeProducing(self) -> None:
"""The remote has caught up after we started buffering!"""
logger.info("[%s] Resume producing", self.id())
self.state = ConnectionStates.ESTABLISHED
self._send_pending_commands()
def stopProducing(self):
def stopProducing(self) -> None:
"""We're never going to send any more data (normally because either
we or the remote has closed the connection)
"""
logger.info("[%s] Stop producing", self.id())
self.on_connection_closed()
def connectionLost(self, reason):
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable]
try:
# Remove us from list of connections to be monitored
@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.on_connection_closed()
def on_connection_closed(self):
def on_connection_closed(self) -> None:
logger.info("[%s] Connection was closed", self.id())
self.state = ConnectionStates.CLOSED
@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# the sentinel context is now active, which may not be correct.
# PreserveLoggingContext() will restore the correct logging context.
def __str__(self):
def __str__(self) -> str:
addr = None
if self.transport:
addr = str(self.transport.getPeer())
@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
addr,
)
def id(self):
def id(self) -> str:
return "%s-%s" % (self.name, self.conn_id)
def lineLengthExceeded(self, line):
def lineLengthExceeded(self, line: str) -> None:
"""Called when we receive a line that is above the maximum line length"""
self.send_error("Line length exceeded")
@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
def connectionMade(self):
def connectionMade(self) -> None:
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
def on_NAME(self, cmd):
def on_NAME(self, cmd: NameCommand) -> None:
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.client_name = client_name
self.server_name = server_name
def connectionMade(self):
def connectionMade(self) -> None:
self.send_command(NameCommand(self.client_name))
super().connectionMade()
# Once we've connected subscribe to the necessary streams
self.replicate()
def on_SERVER(self, cmd):
def on_SERVER(self, cmd: ServerCommand) -> None:
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
def replicate(self):
def replicate(self) -> None:
"""Send the subscription request to the server"""
logger.info("[%s] Subscribing to replication streams", self.id())
@ -529,7 +529,7 @@ pending_commands = LaterGauge(
)
def transport_buffer_size(protocol):
def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int:
if protocol.transport:
size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
return size
@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge(
)
def transport_kernel_read_buffer_size(protocol, read=True):
def transport_kernel_read_buffer_size(
protocol: BaseReplicationStreamProtocol, read: bool = True
) -> int:
SIOCINQ = 0x541B
SIOCOUTQ = 0x5411

View file

@ -14,7 +14,7 @@
import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
import attr
import txredisapi
@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]):
def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
return self.constant
def __set__(self, obj: Optional[T], value: V):
def __set__(self, obj: Optional[T], value: V) -> None:
pass
@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
synapse_stream_name: str
synapse_outbound_redis_connection: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
# a logcontext which we use for processing incoming commands. We declare it as a
@ -108,12 +108,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
"replication_command_handler"
)
def connectionMade(self):
def connectionMade(self) -> None:
logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe)
async def _send_subscribe(self):
async def _send_subscribe(self) -> None:
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
@ -131,12 +131,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# other side won't know we've connected and so won't issue a REPLICATE.
self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
def messageReceived(self, pattern: str, channel: str, message: str) -> None:
"""Received a message from redis."""
with PreserveLoggingContext(self._logging_context):
self._parse_and_dispatch_message(message)
def _parse_and_dispatch_message(self, message: str):
def _parse_and_dispatch_message(self, message: str) -> None:
if message.strip() == "":
# Ignore blank lines
return
@ -181,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
"replication-" + cmd.get_logcontext_id(), lambda: res
)
def connectionLost(self, reason):
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.synapse_handler.lost_connection(self)
@ -193,17 +193,17 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# the sentinel context is now active, which may not be correct.
# PreserveLoggingContext() will restore the correct logging context.
def send_command(self, cmd: Command):
def send_command(self, cmd: Command) -> None:
"""Send a command if connection has been established.
Args:
cmd (Command)
cmd: The command to send
"""
run_as_background_process(
"send-cmd", self._async_send_command, cmd, bg_start_span=False
)
async def _async_send_command(self, cmd: Command):
async def _async_send_command(self, cmd: Command) -> None:
"""Encode a replication command and send it over our outbound connection"""
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
@ -259,7 +259,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
@wrap_as_background_process("redis_ping")
async def _send_ping(self):
async def _send_ping(self) -> None:
for connection in self.pool:
try:
await make_deferred_yieldable(connection.ping())
@ -269,13 +269,13 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
# it's rubbish. We add our own here.
def startedConnecting(self, connector: IConnector):
def startedConnecting(self, connector: IConnector) -> None:
logger.info(
"Connecting to redis server %s", format_address(connector.getDestination())
)
super().startedConnecting(connector)
def clientConnectionFailed(self, connector: IConnector, reason: Failure):
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.info(
"Connection to redis server %s failed: %s",
format_address(connector.getDestination()),
@ -283,7 +283,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
)
super().clientConnectionFailed(connector, reason)
def clientConnectionLost(self, connector: IConnector, reason: Failure):
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.info(
"Connection to redis server %s lost: %s",
format_address(connector.getDestination()),
@ -330,7 +330,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
p = super().buildProtocol(addr)
p = cast(RedisSubscriber, p)

View file

@ -16,16 +16,18 @@
import logging
import random
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
from twisted.internet.interfaces import IAddress
from twisted.internet.protocol import ServerFactory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import PositionCommand
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import EventsStream
from synapse.replication.tcp.streams._base import StreamRow, Token
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@ -56,7 +58,7 @@ class ReplicationStreamProtocolFactory(ServerFactory):
# listener config again or always starting a `ReplicationStreamer`.)
hs.get_replication_streamer()
def buildProtocol(self, addr):
def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol:
return ServerReplicationStreamProtocol(
self.server_name, self.clock, self.command_handler
)
@ -105,7 +107,7 @@ class ReplicationStreamer:
if any(EventsStream.NAME == s.NAME for s in self.streams):
self.clock.looping_call(self.on_notifier_poke, 1000)
def on_notifier_poke(self):
def on_notifier_poke(self) -> None:
"""Checks if there is actually any new data and sends it to the
connections if there is.
@ -137,7 +139,7 @@ class ReplicationStreamer:
run_as_background_process("replication_notifier", self._run_notifier_loop)
async def _run_notifier_loop(self):
async def _run_notifier_loop(self) -> None:
self.is_looping = True
try:
@ -238,7 +240,9 @@ class ReplicationStreamer:
self.is_looping = False
def _batch_updates(updates):
def _batch_updates(
updates: List[Tuple[Token, StreamRow]]
) -> List[Tuple[Optional[Token], StreamRow]]:
"""Takes a list of updates of form [(token, row)] and sets the token to
None for all rows where the next row has the same token. This is used to
implement batching.
@ -254,7 +258,7 @@ def _batch_updates(updates):
if not updates:
return []
new_updates = []
new_updates: List[Tuple[Optional[Token], StreamRow]] = []
for i, update in enumerate(updates[:-1]):
if update[0] == updates[i + 1][0]:
new_updates.append((None, update[1]))

View file

@ -90,7 +90,7 @@ class Stream:
ROW_TYPE: Any = None
@classmethod
def parse_row(cls, row: StreamRow):
def parse_row(cls, row: StreamRow) -> Any:
"""Parse a row received over replication
By default, assumes that the row data is an array object and passes its contents
@ -139,7 +139,7 @@ class Stream:
# The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
@ -200,7 +200,7 @@ def current_token_without_instance(
return lambda instance_name: current_token()
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""

View file

@ -13,12 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Tuple, Type
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr
from ._base import Stream, StreamUpdateResult, Token
from synapse.replication.tcp.streams._base import (
Stream,
StreamRow,
StreamUpdateResult,
Token,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -58,6 +62,9 @@ class EventsStreamRow:
data: "BaseEventsStreamRow"
T = TypeVar("T", bound="BaseEventsStreamRow")
class BaseEventsStreamRow:
"""Base class for rows to be sent in the events stream.
@ -68,7 +75,7 @@ class BaseEventsStreamRow:
TypeId: str
@classmethod
def from_data(cls, data):
def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
"""Parse the data from the replication stream into a row.
By default we just call the constructor with the data list as arguments
@ -221,7 +228,7 @@ class EventsStream(Stream):
return updates, upper_limit, limited
@classmethod
def parse_row(cls, row):
(typ, data) = row
data = TypeToRow[typ].from_data(data)
return EventsStreamRow(typ, data)
def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
(typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
event_stream_row_data = TypeToRow[typ].from_data(data)
return EventsStreamRow(typ, event_stream_row_data)

View file

@ -16,8 +16,7 @@ import itertools
import re
import secrets
import string
from collections.abc import Iterable
from typing import Optional, Tuple
from typing import Iterable, Optional, Tuple
from netaddr import valid_ipv6
@ -197,7 +196,7 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
"""If iterable has maxitems or fewer, return the stringification of a list
containing those items.
Otherwise, return the stringification of a a list with the first maxitems items,
Otherwise, return the stringification of a list with the first maxitems items,
followed by "...".
Args:

View file

@ -14,6 +14,7 @@
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource
@ -53,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
None
IPv4Address("TCP", "127.0.0.1", 0)
)
# Make a new HomeServer object for the worker
@ -345,7 +346,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.clock,
repl_handler,
)
server = self.server_factory.buildProtocol(None)
server = self.server_factory.buildProtocol(
IPv4Address("TCP", "127.0.0.1", 0)
)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)

View file

@ -14,6 +14,7 @@
from typing import Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IProtocol
from twisted.test.proto_helpers import StringTransport
@ -29,7 +30,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
"""Create a new direct TCP replication connection"""
proto = self.factory.buildProtocol(("127.0.0.1", 0))
proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0))
transport = StringTransport()
proto.makeConnection(transport)