Add type hints to _base.

This commit is contained in:
Patrick Cloke 2021-11-09 15:58:14 -05:00
parent a2c8921242
commit 70384c7907
7 changed files with 99 additions and 61 deletions

View file

@ -26,7 +26,6 @@ exclude = (?x)
|synapse/_scripts/register_new_matrix_user.py
|synapse/_scripts/review_recent_signups.py
|synapse/app/__init__.py
|synapse/app/_base.py
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py

View file

@ -22,13 +22,24 @@ import socket
import sys
import traceback
import warnings
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Tuple,
)
from cryptography.utils import CryptographyDeprecationWarning
from typing_extensions import NoReturn
import twisted
from twisted.internet import defer, error, reactor
from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP
from twisted.internet.protocol import ServerFactory
from twisted.internet.tcp import Port
from twisted.logger import LoggingFile, LogLevel
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.threadpool import ThreadPool
@ -60,30 +71,37 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# list of tuples of function, args list, kwargs dict
_sighup_callbacks = []
_sighup_callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
] = []
def register_sighup(func, *args, **kwargs):
def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
"""
Register a function to be called when a SIGHUP occurs.
Args:
func (function): Function to be called when sent a SIGHUP signal.
func: Function to be called when sent a SIGHUP signal.
*args, **kwargs: args and kwargs to be passed to the target function.
"""
_sighup_callbacks.append((func, args, kwargs))
def start_worker_reactor(appname, config, run_command=reactor.run):
def start_worker_reactor(
appname: str,
config: HomeServerConfig,
# mypy can't figure out what the default reactor is.
run_command: Callable[[], None] = reactor.run, # type: ignore[attr-defined]
) -> None:
"""Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor. Pulls configuration from the 'worker' settings in 'config'.
Args:
appname (str): application name which will be sent to syslog
config (synapse.config.Config): config object
run_command (Callable[]): callable that actually runs the reactor
appname: application name which will be sent to syslog
config: config object
run_command: callable that actually runs the reactor
"""
logger = logging.getLogger(config.worker.worker_app)
@ -101,32 +119,33 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
def start_reactor(
appname,
soft_file_limit,
gc_thresholds,
pid_file,
daemonize,
print_pidfile,
logger,
run_command=reactor.run,
):
appname: str,
soft_file_limit: int,
gc_thresholds: Tuple[int, int, int],
pid_file: str,
daemonize: bool,
print_pidfile: bool,
logger: logging.Logger,
# mypy can't figure out what the default reactor is.
run_command: Callable[[], None] = reactor.run, # type: ignore[attr-defined]
) -> None:
"""Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor
Args:
appname (str): application name which will be sent to syslog
soft_file_limit (int):
appname: application name which will be sent to syslog
soft_file_limit:
gc_thresholds:
pid_file (str): name of pid file to write to if daemonize is True
daemonize (bool): true to run the reactor in a background process
print_pidfile (bool): whether to print the pid file, if daemonize is True
logger (logging.Logger): logger instance to pass to Daemonize
run_command (Callable[]): callable that actually runs the reactor
pid_file: name of pid file to write to if daemonize is True
daemonize: true to run the reactor in a background process
print_pidfile: whether to print the pid file, if daemonize is True
logger: logger instance to pass to Daemonize
run_command: callable that actually runs the reactor
"""
def run():
def run() -> None:
logger.info("Running")
setup_jemalloc_stats()
change_resource_limit(soft_file_limit)
@ -151,7 +170,7 @@ def start_reactor(
run()
def quit_with_error(error_string: str) -> NoReturn:
def quit_with_error(error_string: str) -> None:
message_lines = error_string.split("\n")
line_length = min(max(len(line) for line in message_lines), 80) + 2
sys.stderr.write("*" * line_length + "\n")
@ -161,7 +180,7 @@ def quit_with_error(error_string: str) -> NoReturn:
sys.exit(1)
def handle_startup_exception(e: Exception) -> NoReturn:
def handle_startup_exception(e: Exception) -> None:
# Exceptions that occur between setting up the logging and forking or starting
# the reactor are written to the logs, followed by a summary to stderr.
logger.exception("Exception during startup")
@ -185,7 +204,7 @@ def redirect_stdio_to_logs() -> None:
print("Redirected stdout/stderr to logs")
def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
"""Register a callback with the reactor, to be called once it is running
This can be used to initialise parts of the system which require an asynchronous
@ -195,7 +214,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
will exit.
"""
async def wrapper():
async def wrapper() -> None:
try:
await cb(*args, **kwargs)
except Exception:
@ -221,10 +240,11 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
# on as normal.
os._exit(1)
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
# mypy can't figure out what the default reactor is.
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) # type: ignore[attr-defined]
def listen_metrics(bind_addresses, port):
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
"""
Start Prometheus metrics server.
"""
@ -240,7 +260,7 @@ def listen_manhole(
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
):
) -> None:
# twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
# warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
# suppress the warning for now.
@ -259,12 +279,19 @@ def listen_manhole(
)
def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
def listen_tcp(
bind_addresses: Iterable[str],
port: int,
factory: ServerFactory,
# mypy can't figure out what the default reactor is.
reactor: IReactorTCP = reactor, # type: ignore[assignment]
backlog: int = 50,
) -> List[Port]:
"""
Create a TCP socket for a port and several addresses
Returns:
list[twisted.internet.tcp.Port]: listening for TCP connections
list of twisted.internet.tcp.Port listening for TCP connections
"""
r = []
for address in bind_addresses:
@ -273,12 +300,20 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
return r
# IReactorTCP returns an object implementing IListeningPort from listenTCP,
# but we know it will be a Port instance.
return r # type: ignore[return-value]
def listen_ssl(
bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
):
bind_addresses: Iterable[str],
port: int,
factory: ServerFactory,
context_factory: IOpenSSLContextFactory,
# mypy can't figure out what the default reactor is.
reactor: IReactorSSL = reactor, # type: ignore[assignment]
backlog: int = 50,
) -> List[Port]:
"""
Create an TLS-over-TCP socket for a port and several addresses
@ -294,10 +329,13 @@ def listen_ssl(
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
return r
# IReactorSSL incorrectly declares that an int is returned from listenSSL,
# it actually returns an object implementing IListeningPort, but we know it
# will be a Port instance.
return r # type: ignore[return-value]
def refresh_certificate(hs: "HomeServer"):
def refresh_certificate(hs: "HomeServer") -> None:
"""
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
@ -329,7 +367,7 @@ def refresh_certificate(hs: "HomeServer"):
logger.info("Context factories updated.")
async def start(hs: "HomeServer"):
async def start(hs: "HomeServer") -> None:
"""
Start a Synapse server or worker.
@ -360,7 +398,7 @@ async def start(hs: "HomeServer"):
if hasattr(signal, "SIGHUP"):
@wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
def handle_sighup(*args: Any, **kwargs: Any) -> None:
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
@ -373,7 +411,7 @@ async def start(hs: "HomeServer"):
# We defer running the sighup handlers until next reactor tick. This
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs):
def run_sighup(*args: Any, **kwargs: Any) -> None:
# `callFromThread` should be "signal safe" as well as thread
# safe.
reactor.callFromThread(handle_sighup, *args, **kwargs)
@ -436,12 +474,8 @@ async def start(hs: "HomeServer"):
atexit.register(gc.freeze)
def setup_sentry(hs: "HomeServer"):
"""Enable sentry integration, if enabled in configuration
Args:
hs
"""
def setup_sentry(hs: "HomeServer") -> None:
"""Enable sentry integration, if enabled in configuration"""
if not hs.config.metrics.sentry_enabled:
return
@ -466,7 +500,7 @@ def setup_sentry(hs: "HomeServer"):
scope.set_tag("worker_name", name)
def setup_sdnotify(hs: "HomeServer"):
def setup_sdnotify(hs: "HomeServer") -> None:
"""Adds process state hooks to tell systemd what we are up to."""
# Tell systemd our state, if we're using it. This will silently fail if
@ -481,7 +515,7 @@ def setup_sdnotify(hs: "HomeServer"):
sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
def sdnotify(state):
def sdnotify(state: bytes) -> None:
"""
Send a notification to systemd, if the NOTIFY_SOCKET env var is set.
@ -490,7 +524,7 @@ def sdnotify(state):
package which many OSes don't include as a matter of principle.
Args:
state (bytes): notification to send
state: notification to send
"""
if not isinstance(state, bytes):
raise TypeError("sdnotify should be called with a bytes")

View file

@ -149,6 +149,8 @@ class SynapseHomeServer(HomeServer):
)
if tls:
# refresh_certificate should have been called before this.
assert self.tls_server_context_factory is not None
ports = listen_ssl(
bind_addresses,
port,

View file

@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
from prometheus_client import Counter
from twisted.internet.protocol import Factory
from twisted.internet.protocol import ServerFactory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import PositionCommand
@ -38,7 +38,7 @@ stream_updates_counter = Counter(
logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
class ReplicationStreamProtocolFactory(ServerFactory):
"""Factory for new replication connections."""
def __init__(self, hs: "HomeServer"):

View file

@ -33,8 +33,8 @@ from typing import (
cast,
)
import twisted.internet.tcp
from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.tcp import Port
from twisted.web.iweb import IPolicyForHTTPS
from twisted.web.resource import Resource
@ -207,7 +207,7 @@ class HomeServer(metaclass=abc.ABCMeta):
Attributes:
config (synapse.config.homeserver.HomeserverConfig):
_listening_services (list[twisted.internet.tcp.Port]): TCP ports that
_listening_services (list[Port]): TCP ports that
we are listening on to provide HTTP services.
"""
@ -250,7 +250,7 @@ class HomeServer(metaclass=abc.ABCMeta):
# the key we use to sign events and requests
self.signing_key = config.key.signing_key[0]
self.config = config
self._listening_services: List[twisted.internet.tcp.Port] = []
self._listening_services: List[Port] = []
self.start_time: Optional[int] = None
self._instance_id = random_string(5)

View file

@ -38,6 +38,7 @@ from zope.interface import Interface
from twisted.internet.interfaces import (
IReactorCore,
IReactorPluggableNameResolver,
IReactorSSL,
IReactorTCP,
IReactorThreads,
IReactorTime,
@ -66,6 +67,7 @@ JsonDict = Dict[str, Any]
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
IReactorTCP,
IReactorSSL,
IReactorPluggableNameResolver,
IReactorTime,
IReactorCore,

View file

@ -23,7 +23,7 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
from twisted.conch.ssh.keys import Key
from twisted.cred import checkers, portal
from twisted.internet import defer
from twisted.internet.protocol import Factory
from twisted.internet.protocol import ServerFactory
from synapse.config.server import ManholeConfig
@ -65,7 +65,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
-----END RSA PRIVATE KEY-----"""
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
"""Starts a ssh listener with password authentication using
the given username and password. Clients connecting to the ssh
listener will find themselves in a colored python shell with
@ -105,7 +105,8 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
return factory
# ConchFactory is a Factory, not a ServerFactory, but they are identical.
return factory # type: ignore[return-value]
class SynapseManhole(ColoredManhole):