diff --git a/changelog.d/10888.misc b/changelog.d/10888.misc new file mode 100644 index 0000000000..d9c9917881 --- /dev/null +++ b/changelog.d/10888.misc @@ -0,0 +1 @@ +Improve type hinting in `synapse.util`. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 568166db33..86459bdcb6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -102,9 +102,27 @@ disallow_untyped_defs = True [mypy-synapse.util.batching_queue] disallow_untyped_defs = True +[mypy-synapse.util.caches.cached_call] +disallow_untyped_defs = True + [mypy-synapse.util.caches.dictionary_cache] disallow_untyped_defs = True +[mypy-synapse.util.caches.lrucache] +disallow_untyped_defs = True + +[mypy-synapse.util.caches.response_cache] +disallow_untyped_defs = True + +[mypy-synapse.util.caches.stream_change_cache] +disallow_untyped_defs = True + +[mypy-synapse.util.caches.ttl_cache] +disallow_untyped_defs = True + +[mypy-synapse.util.daemonize] +disallow_untyped_defs = True + [mypy-synapse.util.file_consumer] disallow_untyped_defs = True @@ -141,6 +159,9 @@ disallow_untyped_defs = True [mypy-synapse.util.msisdn] disallow_untyped_defs = True +[mypy-synapse.util.patch_inline_callbacks] +disallow_untyped_defs = True + [mypy-synapse.util.ratelimitutils] disallow_untyped_defs = True @@ -162,6 +183,9 @@ disallow_untyped_defs = True [mypy-synapse.util.wheel_timer] disallow_untyped_defs = True +[mypy-synapse.util.versionstring] +disallow_untyped_defs = True + [mypy-tests.handlers.test_user_directory] disallow_untyped_defs = True diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index e58dd91eda..470f4f91a5 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -85,7 +85,7 @@ class CachedCall(Generic[TV]): # result in the deferred, since `awaiting` a deferred destroys its result. # (Also, if it's a Failure, GCing the deferred would log a critical error # about unhandled Failures) - def got_result(r): + def got_result(r: Union[TV, Failure]) -> None: self._result = r self._deferred.addBoth(got_result) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 6262efe072..da502aec11 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -31,6 +31,7 @@ from prometheus_client import Gauge from twisted.internet import defer from twisted.python import failure +from twisted.python.failure import Failure from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.lrucache import LruCache @@ -112,7 +113,7 @@ class DeferredCache(Generic[KT, VT]): self.thread: Optional[threading.Thread] = None @property - def max_entries(self): + def max_entries(self) -> int: return self.cache.max_size def check_thread(self) -> None: @@ -258,7 +259,7 @@ class DeferredCache(Generic[KT, VT]): return False - def cb(result) -> None: + def cb(result: VT) -> None: if compare_and_pop(): self.cache.set(key, result, entry.callbacks) else: @@ -270,7 +271,7 @@ class DeferredCache(Generic[KT, VT]): # not have been. Either way, let's double-check now. entry.invalidate() - def eb(_fail) -> None: + def eb(_fail: Failure) -> None: compare_and_pop() entry.invalidate() @@ -284,11 +285,11 @@ class DeferredCache(Generic[KT, VT]): def prefill( self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None - ): + ) -> None: callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) - def invalidate(self, key): + def invalidate(self, key) -> None: """Delete a key, or tree of entries If the cache is backed by a regular dict, then "key" must be of diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4ff62b403f..a0a7a9de32 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) try: from pympler.asizeof import Asizer - def _get_size_of(val: Any, *, recurse=True) -> int: + def _get_size_of(val: Any, *, recurse: bool = True) -> int: """Get an estimate of the size in bytes of the object. Args: @@ -71,7 +71,7 @@ try: except ImportError: - def _get_size_of(val: Any, *, recurse=True) -> int: + def _get_size_of(val: Any, *, recurse: bool = True) -> int: return 0 @@ -85,15 +85,6 @@ VT = TypeVar("VT") # a general type var, distinct from either KT or VT T = TypeVar("T") - -def enumerate_leaves(node, depth): - if depth == 0: - yield node - else: - for n in node.values(): - yield from enumerate_leaves(n, depth - 1) - - P = TypeVar("P") @@ -102,7 +93,7 @@ class _TimedListNode(ListNode[P]): __slots__ = ["last_access_ts_secs"] - def update_last_access(self, clock: Clock): + def update_last_access(self, clock: Clock) -> None: self.last_access_ts_secs = int(clock.time()) @@ -115,7 +106,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node() @wrap_as_background_process("LruCache._expire_old_entries") -async def _expire_old_entries(clock: Clock, expiry_seconds: int): +async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None: """Walks the global cache list to find cache entries that haven't been accessed in the given number of seconds. """ @@ -163,7 +154,7 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int): logger.info("Dropped %d items from caches", i) -def setup_expire_lru_cache_entries(hs: "HomeServer"): +def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: """Start a background job that expires all cache entries if they have not been accessed for the given number of seconds. """ @@ -183,7 +174,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer"): ) -class _Node: +class _Node(Generic[KT, VT]): __slots__ = [ "_list_node", "_global_list_node", @@ -197,8 +188,8 @@ class _Node: def __init__( self, root: "ListNode[_Node]", - key, - value, + key: KT, + value: VT, cache: "weakref.ReferenceType[LruCache]", clock: Clock, callbacks: Collection[Callable[[], None]] = (), @@ -409,7 +400,7 @@ class LruCache(Generic[KT, VT]): def synchronized(f: FT) -> FT: @wraps(f) - def inner(*args, **kwargs): + def inner(*args: Any, **kwargs: Any) -> Any: with lock: return f(*args, **kwargs) @@ -418,17 +409,19 @@ class LruCache(Generic[KT, VT]): cached_cache_len = [0] if size_callback is not None: - def cache_len(): + def cache_len() -> int: return cached_cache_len[0] else: - def cache_len(): + def cache_len() -> int: return len(cache) self.len = synchronized(cache_len) - def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()): + def add_node( + key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () + ) -> None: node = _Node( list_root, key, @@ -446,7 +439,7 @@ class LruCache(Generic[KT, VT]): if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) - def move_node_to_front(node: _Node): + def move_node_to_front(node: _Node) -> None: node.move_to_front(real_clock, list_root) def delete_node(node: _Node) -> int: @@ -488,7 +481,7 @@ class LruCache(Generic[KT, VT]): default: Optional[T] = None, callbacks: Collection[Callable[[], None]] = (), update_metrics: bool = True, - ): + ) -> Union[None, T, VT]: node = cache.get(key, None) if node is not None: move_node_to_front(node) @@ -502,7 +495,9 @@ class LruCache(Generic[KT, VT]): return default @synchronized - def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()): + def cache_set( + key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = () + ) -> None: node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause @@ -547,7 +542,7 @@ class LruCache(Generic[KT, VT]): ... @synchronized - def cache_pop(key: KT, default: Optional[T] = None): + def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: node = cache.get(key, None) if node: delete_node(node) @@ -612,25 +607,25 @@ class LruCache(Generic[KT, VT]): self.contains = cache_contains self.clear = cache_clear - def __getitem__(self, key): + def __getitem__(self, key: KT) -> VT: result = self.get(key, self.sentinel) if result is self.sentinel: raise KeyError() else: - return result + return cast(VT, result) - def __setitem__(self, key, value): + def __setitem__(self, key: KT, value: VT) -> None: self.set(key, value) - def __delitem__(self, key, value): + def __delitem__(self, key: KT, value: VT) -> None: result = self.pop(key, self.sentinel) if result is self.sentinel: raise KeyError() - def __len__(self): + def __len__(self) -> int: return self.len() - def __contains__(self, key): + def __contains__(self, key: KT) -> bool: return self.contains(key) def set_cache_factor(self, factor: float) -> bool: diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index ed7204336f..88ccf44337 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -104,8 +104,8 @@ class ResponseCache(Generic[KV]): return None def _set( - self, context: ResponseCacheContext[KV], deferred: defer.Deferred - ) -> defer.Deferred: + self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]" + ) -> "defer.Deferred[RV]": """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -126,7 +126,7 @@ class ResponseCache(Generic[KV]): key = context.cache_key self.pending_result_cache[key] = result - def on_complete(r): + def on_complete(r: RV) -> RV: # if this cache has a non-zero timeout, and the callback has not cleared # the should_cache bit, we leave it in the cache for now and schedule # its removal later. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 27b1da235e..330709b8b7 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -40,10 +40,10 @@ class StreamChangeCache: self, name: str, current_stream_pos: int, - max_size=10000, + max_size: int = 10000, prefilled_cache: Optional[Mapping[EntityType, int]] = None, - ): - self._original_max_size = max_size + ) -> None: + self._original_max_size: int = max_size self._max_size = math.floor(max_size) self._entity_to_key: Dict[EntityType, int] = {} diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 46afe3f934..0b9ac26b69 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -159,12 +159,12 @@ class TTLCache(Generic[KT, VT]): del self._expiry_list[0] -@attr.s(frozen=True, slots=True) -class _CacheEntry: +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313 """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. - expiry_time = attr.ib(type=float) - ttl = attr.ib(type=float) - key = attr.ib() - value = attr.ib() + expiry_time: float + ttl: float + key: Any # should be KT + value: Any # should be VT diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index f1a351cfd4..de04f34e4e 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -19,6 +19,8 @@ import logging import os import signal import sys +from types import FrameType, TracebackType +from typing import NoReturn, Type def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: @@ -97,7 +99,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # (we don't normally expect reactor.run to raise any exceptions, but this will # also catch any other uncaught exceptions before we get that far.) - def excepthook(type_, value, traceback): + def excepthook( + type_: Type[BaseException], value: BaseException, traceback: TracebackType + ) -> None: logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) sys.excepthook = excepthook @@ -119,7 +123,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - sys.exit(1) # write a log line on SIGTERM. - def sigterm(signum, frame): + def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn: logger.warning("Caught signal %s. Stopping daemon." % signum) sys.exit(0) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 1b82dca81b..1e784b3f1f 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -14,9 +14,11 @@ import logging from functools import wraps -from typing import Any, Callable, Optional, TypeVar, cast +from types import TracebackType +from typing import Any, Callable, Optional, Type, TypeVar, cast from prometheus_client import Counter +from typing_extensions import Protocol from synapse.logging.context import ( ContextResourceUsage, @@ -24,6 +26,7 @@ from synapse.logging.context import ( current_context, ) from synapse.metrics import InFlightGauge +from synapse.util import Clock logger = logging.getLogger(__name__) @@ -64,6 +67,10 @@ in_flight = InFlightGauge( T = TypeVar("T", bound=Callable[..., Any]) +class HasClock(Protocol): + clock: Clock + + def measure_func(name: Optional[str] = None) -> Callable[[T], T]: """ Used to decorate an async function with a `Measure` context manager. @@ -86,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]: block_name = func.__name__ if name is None else name @wraps(func) - async def measured_func(self, *args, **kwargs): + async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any: with Measure(self.clock, block_name): r = await func(self, *args, **kwargs) return r @@ -104,10 +111,10 @@ class Measure: "start", ] - def __init__(self, clock, name: str): + def __init__(self, clock: Clock, name: str) -> None: """ Args: - clock: A n object with a "time()" method, which returns the current + clock: An object with a "time()" method, which returns the current time in seconds. name: The name of the metric to report. """ @@ -124,7 +131,7 @@ class Measure: assert isinstance(curr_context, LoggingContext) parent_context = curr_context self._logging_context = LoggingContext(str(curr_context), parent_context) - self.start: Optional[int] = None + self.start: Optional[float] = None def __enter__(self) -> "Measure": if self.start is not None: @@ -138,7 +145,12 @@ class Measure: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: if self.start is None: raise RuntimeError("Measure() block exited without being entered") @@ -168,8 +180,9 @@ class Measure: """ return self._logging_context.get_resource_usage() - def _update_in_flight(self, metrics): + def _update_in_flight(self, metrics) -> None: """Gets called when processing in flight metrics""" + assert self.start is not None duration = self.clock.time() - self.start metrics.real_time_max = max(metrics.real_time_max, duration) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 9dd010af3b..1f18654d47 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -14,7 +14,7 @@ import functools import sys -from typing import Any, Callable, List +from typing import Any, Callable, Generator, List, TypeVar from twisted.internet import defer from twisted.internet.defer import Deferred @@ -24,6 +24,9 @@ from twisted.python.failure import Failure _already_patched = False +T = TypeVar("T") + + def do_patch() -> None: """ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit @@ -37,15 +40,19 @@ def do_patch() -> None: if _already_patched: return - def new_inline_callbacks(f): + def new_inline_callbacks( + f: Callable[..., Generator["Deferred[object]", object, T]] + ) -> Callable[..., "Deferred[T]"]: @functools.wraps(f) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]": start_context = current_context() changes: List[str] = [] - orig = orig_inline_callbacks(_check_yield_points(f, changes)) + orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks( + _check_yield_points(f, changes) + ) try: - res = orig(*args, **kwargs) + res: "Deferred[T]" = orig(*args, **kwargs) except Exception: if current_context() != start_context: for err in changes: @@ -84,7 +91,7 @@ def do_patch() -> None: print(err, file=sys.stderr) raise Exception(err) - def check_ctx(r): + def check_ctx(r: T) -> T: if current_context() != start_context: for err in changes: print(err, file=sys.stderr) @@ -107,7 +114,10 @@ def do_patch() -> None: _already_patched = True -def _check_yield_points(f: Callable, changes: List[str]) -> Callable: +def _check_yield_points( + f: Callable[..., Generator["Deferred[object]", object, T]], + changes: List[str], +) -> Callable: """Wraps a generator that is about to be passed to defer.inlineCallbacks checking that after every yield the log contexts are correct. @@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable: from synapse.logging.context import current_context @functools.wraps(f) - def check_yield_points_inner(*args, **kwargs): + def check_yield_points_inner( + *args: Any, **kwargs: Any + ) -> Generator["Deferred[object]", object, T]: gen = f(*args, **kwargs) last_yield_line_no = gen.gi_frame.f_lineno diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 1c20b24bbe..899ee0adc8 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -15,14 +15,18 @@ import logging import os import subprocess +from types import ModuleType +from typing import Dict logger = logging.getLogger(__name__) +version_cache: Dict[ModuleType, str] = {} -def get_version_string(module) -> str: + +def get_version_string(module: ModuleType) -> str: """Given a module calculate a git-aware version string for it. - If called on a module not in a git checkout will return `__verison__`. + If called on a module not in a git checkout will return `__version__`. Args: module (module) @@ -31,11 +35,13 @@ def get_version_string(module) -> str: str """ - cached_version = getattr(module, "_synapse_version_string_cache", None) - if cached_version: + cached_version = version_cache.get(module) + if cached_version is not None: return cached_version - version_string = module.__version__ + # We want this to fail loudly with an AttributeError. Type-ignore this so + # mypy only considers the happy path. + version_string = module.__version__ # type: ignore[attr-defined] try: null = open(os.devnull, "w") @@ -97,10 +103,15 @@ def get_version_string(module) -> str: s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - version_string = "%s (%s)" % (module.__version__, git_version) + version_string = "%s (%s)" % ( + # If the __version__ attribute doesn't exist, we'll have failed + # loudly above. + module.__version__, # type: ignore[attr-defined] + git_version, + ) except Exception as e: logger.info("Failed to check for git repository: %s", e) - module._synapse_version_string_cache = version_string + version_cache[module] = version_string return version_string