diff --git a/changelog.d/12667.misc b/changelog.d/12667.misc new file mode 100644 index 000000000..2b17502d6 --- /dev/null +++ b/changelog.d/12667.misc @@ -0,0 +1 @@ +Use `ParamSpec` to refine type hints. diff --git a/poetry.lock b/poetry.lock index ddafaaeba..f649efdf2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1563,7 +1563,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "eebc9e1d720e2e866f5fddda98ce83d858949a6fdbe30c7e5aef4cf9d17be498" +content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index 4c51b8c4a..2c4b7eb08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,7 +143,9 @@ netaddr = ">=0.7.18" Jinja2 = ">=3.0" bleach = ">=1.4.3" # We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0. -typing-extensions = ">=3.10.0" +# Additionally we need https://github.com/python/typing/pull/817 to allow types to be +# generic over ParamSpecs. +typing-extensions = ">=3.10.0.1" # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. cryptography = ">=3.4.7" diff --git a/synapse/app/_base.py b/synapse/app/_base.py index d28b87a3f..3623c1724 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -38,6 +38,7 @@ from typing import ( from cryptography.utils import CryptographyDeprecationWarning from matrix_common.versionstring import get_distribution_version_string +from typing_extensions import ParamSpec import twisted from twisted.internet import defer, error, reactor as _reactor @@ -81,11 +82,12 @@ logger = logging.getLogger(__name__) # list of tuples of function, args list, kwargs dict _sighup_callbacks: List[ - Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] + Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]] ] = [] +P = ParamSpec("P") -def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None: +def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: """ Register a function to be called when a SIGHUP occurs. @@ -93,7 +95,9 @@ def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> Non func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - _sighup_callbacks.append((func, args, kwargs)) + # This type-ignore should be redundant once we use a mypy release with + # https://github.com/python/mypy/pull/12668. + _sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] def start_worker_reactor( @@ -214,7 +218,9 @@ def redirect_stdio_to_logs() -> None: print("Redirected stdout/stderr to logs") -def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None: +def register_start( + cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs +) -> None: """Register a callback with the reactor, to be called once it is running This can be used to initialise parts of the system which require an asynchronous diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index 98555c8c0..8437ce52d 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -22,9 +22,12 @@ from typing import ( List, Optional, Set, + TypeVar, Union, ) +from typing_extensions import ParamSpec + from synapse.api.presence import UserPresenceState from synapse.util.async_helpers import maybe_awaitable @@ -40,6 +43,10 @@ GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]] logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") + + def load_legacy_presence_router(hs: "HomeServer") -> None: """Wrapper that loads a presence router module configured using the old configuration, and registers the hooks they implement. @@ -63,13 +70,15 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed - def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + def async_wrapper( + f: Optional[Callable[P, R]] + ) -> Optional[Callable[P, Awaitable[R]]]: # f might be None if the callback isn't implemented by the module. In this # case we don't want to register a callback at all so we return None. if f is None: return None - def run(*args: Any, **kwargs: Any) -> Awaitable: + def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]: # Assertion required because mypy can't prove we won't change `f` # back to `None`. See # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions @@ -80,7 +89,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: return run # Register the hooks through the module API. - hooks = { + hooks: Dict[str, Optional[Callable[..., Any]]] = { hook: async_wrapper(getattr(presence_router, hook, None)) for hook in presence_router_methods } diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 834fe1b62..73f92d2df 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -30,6 +30,7 @@ from typing import ( import attr import jinja2 +from typing_extensions import ParamSpec from twisted.internet import defer from twisted.web.resource import Resource @@ -129,6 +130,7 @@ if TYPE_CHECKING: T = TypeVar("T") +P = ParamSpec("P") """ This package defines the 'stable' API which can be used by extension modules which @@ -799,9 +801,9 @@ class ModuleApi: def run_db_interaction( self, desc: str, - func: Callable[..., T], - *args: Any, - **kwargs: Any, + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, ) -> "defer.Deferred[T]": """Run a function with a database connection @@ -817,8 +819,9 @@ class ModuleApi: Returns: Deferred[object]: result of func """ + # type-ignore: See https://github.com/python/mypy/issues/8862 return defer.ensureDeferred( - self._store.db_pool.runInteraction(desc, func, *args, **kwargs) + self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] ) def complete_sso_login( @@ -1296,9 +1299,9 @@ class ModuleApi: async def defer_to_thread( self, - f: Callable[..., T], - *args: Any, - **kwargs: Any, + f: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, ) -> T: """Runs the given function in a separate thread from Synapse's thread pool. diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 0152a0c66..ad025c8a4 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -15,8 +15,6 @@ import logging from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple -from twisted.web.server import Request - from synapse.api.constants import Membership from synapse.api.errors import SynapseError from synapse.http.server import HttpServer @@ -97,7 +95,7 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} def on_PUT( - self, request: Request, room_identifier: str, txn_id: str + self, request: SynapseRequest, room_identifier: str, txn_id: str ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 914fb3acf..61375651b 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -15,7 +15,9 @@ """This module contains logic for storing HTTP PUT transactions. This is used to ensure idempotency when performing PUTs using the REST API.""" import logging -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple + +from typing_extensions import ParamSpec from twisted.python.failure import Failure from twisted.web.server import Request @@ -32,6 +34,9 @@ logger = logging.getLogger(__name__) CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins +P = ParamSpec("P") + + class HttpTransactionCache: def __init__(self, hs: "HomeServer"): self.hs = hs @@ -65,9 +70,9 @@ class HttpTransactionCache: def fetch_or_execute_request( self, request: Request, - fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], - *args: Any, - **kwargs: Any, + fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], + *args: P.args, + **kwargs: P.kwargs, ) -> Awaitable[Tuple[int, JsonDict]]: """A helper function for fetch_or_execute which extracts a transaction key from the given request. @@ -82,9 +87,9 @@ class HttpTransactionCache: def fetch_or_execute( self, txn_key: str, - fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], - *args: Any, - **kwargs: Any, + fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], + *args: P.args, + **kwargs: P.kwargs, ) -> Awaitable[Tuple[int, JsonDict]]: """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 2255e55f6..41f566b64 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -192,7 +192,7 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. -_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] +_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] P = ParamSpec("P") R = TypeVar("R") @@ -239,7 +239,9 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): + def call_after( + self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs + ) -> None: """Call the given callback on the main twisted thread after the transaction has finished. @@ -256,11 +258,12 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.after_callbacks is not None - self.after_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def call_on_exception( - self, callback: Callable[..., object], *args: Any, **kwargs: Any - ): + self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs + ) -> None: """Call the given callback on the main twisted thread after the transaction has failed. @@ -274,7 +277,8 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.exception_callbacks is not None - self.exception_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() @@ -549,9 +553,9 @@ class DatabasePool: desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], - func: Callable[..., R], - *args: Any, - **kwargs: Any, + func: Callable[Concatenate[LoggingTransaction, P], R], + *args: P.args, + **kwargs: P.kwargs, ) -> R: """Start a new database transaction with the given connection. @@ -581,7 +585,10 @@ class DatabasePool: # will fail if we have to repeat the transaction. # For now, we just log an error, and hope that it works on the first attempt. # TODO: raise an exception. - for i, arg in enumerate(args): + + # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see + # https://github.com/python/mypy/pull/12668 + for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] if inspect.isgenerator(arg): logger.error( "Programming error: generator passed to new_transaction as " @@ -589,7 +596,9 @@ class DatabasePool: i, func, ) - for name, val in kwargs.items(): + # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see + # https://github.com/python/mypy/pull/12668 + for name, val in kwargs.items(): # type: ignore[attr-defined] if inspect.isgenerator(val): logger.error( "Programming error: generator passed to new_transaction as " diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9a6c2fd47..ed29a0a5e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1648,8 +1648,12 @@ class PersistEventsStore: txn.call_after(prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: - # Invalidate the caches for the redacted event, note that these caches - # are also cleared as part of event replication in _invalidate_caches_for_event. + """Invalidate the caches for the redacted event. + + Note that these caches are also cleared as part of event replication in + _invalidate_caches_for_event. + """ + assert event.redacts is not None txn.call_after(self.store._invalidate_get_event_cache, event.redacts) txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index e27c5d298..b91020117 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -42,7 +42,7 @@ from typing import ( ) import attr -from typing_extensions import AsyncContextManager, Literal +from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -237,9 +237,16 @@ async def concurrently_execute( ) +P = ParamSpec("P") +R = TypeVar("R") + + async def yieldable_gather_results( - func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any -) -> List[T]: + func: Callable[Concatenate[T, P], Awaitable[R]], + iter: Iterable[T], + *args: P.args, + **kwargs: P.kwargs, +) -> List[R]: """Executes the function with each argument concurrently. Args: @@ -255,7 +262,15 @@ async def yieldable_gather_results( try: return await make_deferred_yieldable( defer.gatherResults( - [run_in_background(func, item, *args, **kwargs) for item in iter], + # type-ignore: mypy reports two errors: + # error: Argument 1 to "run_in_background" has incompatible type + # "Callable[[T, **P], Awaitable[R]]"; expected + # "Callable[[T, **P], Awaitable[R]]" [arg-type] + # error: Argument 2 to "run_in_background" has incompatible type + # "T"; expected "[T, **P.args]" [arg-type] + # The former looks like a mypy bug, and the latter looks like a + # false positive. + [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] consumeErrors=True, ) ) @@ -577,9 +592,6 @@ class ReadWriteLock: return _ctx_manager() -R = TypeVar("R") - - def timeout_deferred( deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime ) -> "defer.Deferred[_T]": diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 91837655f..b580bdd0d 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,7 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + List, + Optional, + TypeVar, + Union, +) + +from typing_extensions import ParamSpec from twisted.internet import defer @@ -75,7 +87,11 @@ class Distributor: run_as_background_process(name, self.signals[name].fire, *args, **kwargs) -class Signal: +P = ParamSpec("P") +R = TypeVar("R") + + +class Signal(Generic[P]): """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -87,16 +103,16 @@ class Signal: def __init__(self, name: str): self.name: str = name - self.observers: List[Callable] = [] + self.observers: List[Callable[P, Any]] = [] - def observe(self, observer: Callable) -> None: + def observe(self, observer: Callable[P, Any]) -> None: """Adds a new callable to the observer list which will be invoked by the 'fire' method. Each observer callable may return a Deferred.""" self.observers.append(observer) - def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]": + def fire(self, *args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]": """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is not an error to fire a signal with no observers. @@ -104,7 +120,7 @@ class Signal: Returns a Deferred that will complete when all the observers have completed.""" - async def do(observer: Callable[..., Any]) -> Any: + async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]: try: return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: @@ -114,6 +130,7 @@ class Signal: observer, e, ) + return None deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 98ee49af6..bc3b4938e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -15,10 +15,10 @@ import logging from functools import wraps from types import TracebackType -from typing import Any, Callable, Optional, Type, TypeVar, cast +from typing import Awaitable, Callable, Optional, Type, TypeVar from prometheus_client import Counter -from typing_extensions import Protocol +from typing_extensions import Concatenate, ParamSpec, Protocol from synapse.logging.context import ( ContextResourceUsage, @@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge( ) -T = TypeVar("T", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") class HasClock(Protocol): clock: Clock -def measure_func(name: Optional[str] = None) -> Callable[[T], T]: - """ - Used to decorate an async function with a `Measure` context manager. +def measure_func( + name: Optional[str] = None, +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """Decorate an async method with a `Measure` context manager. + + The Measure is created using `self.clock`; it should only be used to decorate + methods in classes defining an instance-level `clock` attribute. Usage: @@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]: """ - def wrapper(func: T) -> T: + def wrapper( + func: Callable[Concatenate[HasClock, P], Awaitable[R]] + ) -> Callable[P, Awaitable[R]]: block_name = func.__name__ if name is None else name @wraps(func) - async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any: + async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R: with Measure(self.clock, block_name): r = await func(self, *args, **kwargs) return r - return cast(T, measured_func) + # There are some shenanigans here, because we're decorating a method but + # explicitly making use of the `self` parameter. The key thing here is that the + # return type within the return type for `measure_func` itself describes how the + # decorated function will be called. + return measured_func # type: ignore[return-value] - return wrapper + return wrapper # type: ignore[return-value] class Measure: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index dace68666..f97f98a05 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -16,6 +16,8 @@ import functools import sys from typing import Any, Callable, Generator, List, TypeVar, cast +from typing_extensions import ParamSpec + from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.python.failure import Failure @@ -25,6 +27,7 @@ _already_patched = False T = TypeVar("T") +P = ParamSpec("P") def do_patch() -> None: @@ -41,13 +44,13 @@ def do_patch() -> None: return def new_inline_callbacks( - f: Callable[..., Generator["Deferred[object]", object, T]] - ) -> Callable[..., "Deferred[T]"]: + f: Callable[P, Generator["Deferred[object]", object, T]] + ) -> Callable[P, "Deferred[T]"]: @functools.wraps(f) - def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]": + def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]": start_context = current_context() changes: List[str] = [] - orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks( + orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks( _check_yield_points(f, changes) ) @@ -115,7 +118,7 @@ def do_patch() -> None: def _check_yield_points( - f: Callable[..., Generator["Deferred[object]", object, T]], + f: Callable[P, Generator["Deferred[object]", object, T]], changes: List[str], ) -> Callable: """Wraps a generator that is about to be passed to defer.inlineCallbacks @@ -138,7 +141,7 @@ def _check_yield_points( @functools.wraps(f) def check_yield_points_inner( - *args: Any, **kwargs: Any + *args: P.args, **kwargs: P.kwargs ) -> Generator["Deferred[object]", object, T]: gen = f(*args, **kwargs)