forked from MirrorHub/synapse
Add most missing type hints to synapse.util (#11328)
This commit is contained in:
parent
3a1462f7e0
commit
7468723697
10 changed files with 161 additions and 165 deletions
1
changelog.d/11328.misc
Normal file
1
changelog.d/11328.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add type hints to `synapse.util`.
|
87
mypy.ini
87
mypy.ini
|
@ -196,92 +196,11 @@ disallow_untyped_defs = True
|
|||
[mypy-synapse.streams.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.batching_queue]
|
||||
[mypy-synapse.util.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.caches.cached_call]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.caches.dictionary_cache]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.caches.lrucache]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.caches.response_cache]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.caches.stream_change_cache]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.caches.ttl_cache]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.daemonize]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.file_consumer]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.frozenutils]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.hash]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.httpresourcetree]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.iterutils]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.linked_list]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.logcontext]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.logformatter]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.macaroons]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.manhole]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.module_loader]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.msisdn]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.patch_inline_callbacks]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.ratelimitutils]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.retryutils]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.rlimit]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.stringutils]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.templates]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.threepids]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.wheel_timer]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.util.versionstring]
|
||||
disallow_untyped_defs = True
|
||||
[mypy-synapse.util.caches.treecache]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-tests.handlers.test_user_directory]
|
||||
disallow_untyped_defs = True
|
||||
|
|
|
@ -27,6 +27,7 @@ from typing import (
|
|||
Generic,
|
||||
Hashable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Optional,
|
||||
Set,
|
||||
TypeVar,
|
||||
|
@ -40,7 +41,6 @@ from typing_extensions import ContextManager
|
|||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError
|
||||
from twisted.internet.interfaces import IReactorTime
|
||||
from twisted.python import failure
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import (
|
||||
|
@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
|
|||
object.__setattr__(self, "_result", None)
|
||||
object.__setattr__(self, "_observers", [])
|
||||
|
||||
def callback(r):
|
||||
def callback(r: _T) -> _T:
|
||||
object.__setattr__(self, "_result", (True, r))
|
||||
|
||||
# once we have set _result, no more entries will be added to _observers,
|
||||
|
@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
|
|||
)
|
||||
return r
|
||||
|
||||
def errback(f):
|
||||
def errback(f: Failure) -> Optional[Failure]:
|
||||
object.__setattr__(self, "_result", (False, f))
|
||||
|
||||
# once we have set _result, no more entries will be added to _observers,
|
||||
|
@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
|
|||
for observer in observers:
|
||||
# This is a little bit of magic to correctly propagate stack
|
||||
# traces when we `await` on one of the observer deferreds.
|
||||
f.value.__failure__ = f
|
||||
f.value.__failure__ = f # type: ignore[union-attr]
|
||||
try:
|
||||
observer.errback(f)
|
||||
except Exception as e:
|
||||
|
@ -314,7 +314,7 @@ class Linearizer:
|
|||
# will release the lock.
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager(_):
|
||||
def _ctx_manager(_: None) -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
@ -355,7 +355,7 @@ class Linearizer:
|
|||
new_defer = make_deferred_yieldable(defer.Deferred())
|
||||
entry.deferreds[new_defer] = 1
|
||||
|
||||
def cb(_r):
|
||||
def cb(_r: None) -> "defer.Deferred[None]":
|
||||
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
entry.count += 1
|
||||
|
||||
|
@ -371,7 +371,7 @@ class Linearizer:
|
|||
# code must be synchronous, so this is the only sensible place.)
|
||||
return self._clock.sleep(0)
|
||||
|
||||
def eb(e):
|
||||
def eb(e: Failure) -> Failure:
|
||||
logger.info("defer %r got err %r", new_defer, e)
|
||||
if isinstance(e, CancelledError):
|
||||
logger.debug(
|
||||
|
@ -435,7 +435,7 @@ class ReadWriteLock:
|
|||
await make_deferred_yieldable(curr_writer)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
def _ctx_manager() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
@ -464,7 +464,7 @@ class ReadWriteLock:
|
|||
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
def _ctx_manager() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
@ -524,7 +524,7 @@ def timeout_deferred(
|
|||
|
||||
delayed_call = reactor.callLater(timeout, time_it_out)
|
||||
|
||||
def convert_cancelled(value: failure.Failure):
|
||||
def convert_cancelled(value: Failure) -> Failure:
|
||||
# if the original deferred was cancelled, and our timeout has fired, then
|
||||
# the reason it was cancelled was due to our timeout. Turn the CancelledError
|
||||
# into a TimeoutError.
|
||||
|
@ -534,7 +534,7 @@ def timeout_deferred(
|
|||
|
||||
deferred.addErrback(convert_cancelled)
|
||||
|
||||
def cancel_timeout(result):
|
||||
def cancel_timeout(result: _T) -> _T:
|
||||
# stop the pending call to cancel the deferred if it's been fired
|
||||
if delayed_call.active():
|
||||
delayed_call.cancel()
|
||||
|
@ -542,11 +542,11 @@ def timeout_deferred(
|
|||
|
||||
deferred.addBoth(cancel_timeout)
|
||||
|
||||
def success_cb(val):
|
||||
def success_cb(val: _T) -> None:
|
||||
if not new_d.called:
|
||||
new_d.callback(val)
|
||||
|
||||
def failure_cb(val):
|
||||
def failure_cb(val: Failure) -> None:
|
||||
if not new_d.called:
|
||||
new_d.errback(val)
|
||||
|
||||
|
@ -557,13 +557,13 @@ def timeout_deferred(
|
|||
|
||||
# This class can't be generic because it uses slots with attrs.
|
||||
# See: https://github.com/python-attrs/attrs/issues/313
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DoneAwaitable: # should be: Generic[R]
|
||||
"""Simple awaitable that returns the provided value."""
|
||||
|
||||
value = attr.ib(type=Any) # should be: R
|
||||
value: Any # should be: R
|
||||
|
||||
def __await__(self):
|
||||
def __await__(self) -> Any:
|
||||
return self
|
||||
|
||||
def __iter__(self) -> "DoneAwaitable":
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
|||
import typing
|
||||
from enum import Enum, auto
|
||||
from sys import intern
|
||||
from typing import Callable, Dict, Optional, Sized
|
||||
from typing import Any, Callable, Dict, List, Optional, Sized
|
||||
|
||||
import attr
|
||||
from prometheus_client.core import Gauge
|
||||
|
@ -58,20 +58,20 @@ class EvictionReason(Enum):
|
|||
time = auto()
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class CacheMetric:
|
||||
|
||||
_cache = attr.ib()
|
||||
_cache_type = attr.ib(type=str)
|
||||
_cache_name = attr.ib(type=str)
|
||||
_collect_callback = attr.ib(type=Optional[Callable])
|
||||
_cache: Sized
|
||||
_cache_type: str
|
||||
_cache_name: str
|
||||
_collect_callback: Optional[Callable]
|
||||
|
||||
hits = attr.ib(default=0)
|
||||
misses = attr.ib(default=0)
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
|
||||
factory=collections.Counter
|
||||
)
|
||||
memory_usage = attr.ib(default=None)
|
||||
memory_usage: Optional[int] = None
|
||||
|
||||
def inc_hits(self) -> None:
|
||||
self.hits += 1
|
||||
|
@ -89,13 +89,14 @@ class CacheMetric:
|
|||
self.memory_usage += memory
|
||||
|
||||
def dec_memory_usage(self, memory: int) -> None:
|
||||
assert self.memory_usage is not None
|
||||
self.memory_usage -= memory
|
||||
|
||||
def clear_memory_usage(self) -> None:
|
||||
if self.memory_usage is not None:
|
||||
self.memory_usage = 0
|
||||
|
||||
def describe(self):
|
||||
def describe(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def collect(self) -> None:
|
||||
|
@ -118,8 +119,9 @@ class CacheMetric:
|
|||
self.eviction_size_by_reason[reason]
|
||||
)
|
||||
cache_total.labels(self._cache_name).set(self.hits + self.misses)
|
||||
if getattr(self._cache, "max_size", None):
|
||||
cache_max_size.labels(self._cache_name).set(self._cache.max_size)
|
||||
max_size = getattr(self._cache, "max_size", None)
|
||||
if max_size:
|
||||
cache_max_size.labels(self._cache_name).set(max_size)
|
||||
|
||||
if TRACK_MEMORY_USAGE:
|
||||
# self.memory_usage can be None if nothing has been inserted
|
||||
|
@ -193,7 +195,7 @@ KNOWN_KEYS = {
|
|||
}
|
||||
|
||||
|
||||
def intern_string(string):
|
||||
def intern_string(string: Optional[str]) -> Optional[str]:
|
||||
"""Takes a (potentially) unicode string and interns it if it's ascii"""
|
||||
if string is None:
|
||||
return None
|
||||
|
@ -204,7 +206,7 @@ def intern_string(string):
|
|||
return string
|
||||
|
||||
|
||||
def intern_dict(dictionary):
|
||||
def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Takes a dictionary and interns well known keys and their values"""
|
||||
return {
|
||||
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
|
||||
|
@ -212,7 +214,7 @@ def intern_dict(dictionary):
|
|||
}
|
||||
|
||||
|
||||
def _intern_known_values(key, value):
|
||||
def _intern_known_values(key: str, value: Any) -> Any:
|
||||
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
|
||||
|
||||
if key in intern_keys:
|
||||
|
|
|
@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
|
|||
callbacks = [callback] if callback else []
|
||||
self.cache.set(key, value, callbacks=callbacks)
|
||||
|
||||
def invalidate(self, key) -> None:
|
||||
def invalidate(self, key: KT) -> None:
|
||||
"""Delete a key, or tree of entries
|
||||
|
||||
If the cache is backed by a regular dict, then "key" must be of
|
||||
|
|
|
@ -19,12 +19,15 @@ import logging
|
|||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Hashable,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
|
@ -32,6 +35,7 @@ from typing import (
|
|||
from weakref import WeakValueDictionary
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util import unwrapFirstError
|
||||
|
@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
|
|||
|
||||
|
||||
class _CacheDescriptorBase:
|
||||
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
|
||||
def __init__(
|
||||
self,
|
||||
orig: Callable[..., Any],
|
||||
num_args: Optional[int],
|
||||
cache_context: bool = False,
|
||||
):
|
||||
self.orig = orig
|
||||
|
||||
arg_spec = inspect.getfullargspec(orig)
|
||||
|
@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
orig,
|
||||
orig: Callable[..., Any],
|
||||
max_entries: int = 1000,
|
||||
cache_context: bool = False,
|
||||
):
|
||||
super().__init__(orig, num_args=None, cache_context=cache_context)
|
||||
self.max_entries = max_entries
|
||||
|
||||
def __get__(self, obj, owner):
|
||||
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
||||
cache: LruCache[CacheKey, Any] = LruCache(
|
||||
cache_name=self.orig.__name__,
|
||||
max_size=self.max_entries,
|
||||
|
@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
|||
sentinel = LruCacheDescriptor._Sentinel.sentinel
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def _wrapped(*args, **kwargs):
|
||||
def _wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
callbacks = (invalidate_callback,) if invalidate_callback else ()
|
||||
|
||||
|
@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|||
return r1 + r2
|
||||
|
||||
Args:
|
||||
num_args (int): number of positional arguments (excluding ``self`` and
|
||||
num_args: number of positional arguments (excluding ``self`` and
|
||||
``cache_context``) to use as cache keys. Defaults to all named
|
||||
args of the function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orig,
|
||||
max_entries=1000,
|
||||
num_args=None,
|
||||
tree=False,
|
||||
cache_context=False,
|
||||
iterable=False,
|
||||
orig: Callable[..., Any],
|
||||
max_entries: int = 1000,
|
||||
num_args: Optional[int] = None,
|
||||
tree: bool = False,
|
||||
cache_context: bool = False,
|
||||
iterable: bool = False,
|
||||
prune_unread_entries: bool = True,
|
||||
):
|
||||
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
||||
|
@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|||
self.iterable = iterable
|
||||
self.prune_unread_entries = prune_unread_entries
|
||||
|
||||
def __get__(self, obj, owner):
|
||||
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
||||
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
||||
name=self.orig.__name__,
|
||||
max_entries=self.max_entries,
|
||||
|
@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|||
get_cache_key = self.cache_key_builder
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def _wrapped(*args, **kwargs):
|
||||
def _wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||
# whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
|
@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
of results.
|
||||
"""
|
||||
|
||||
def __init__(self, orig, cached_method_name, list_name, num_args=None):
|
||||
def __init__(
|
||||
self,
|
||||
orig: Callable[..., Any],
|
||||
cached_method_name: str,
|
||||
list_name: str,
|
||||
num_args: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
orig (function)
|
||||
cached_method_name (str): The name of the cached method.
|
||||
list_name (str): Name of the argument which is the bulk lookup list
|
||||
num_args (int): number of positional arguments (excluding ``self``,
|
||||
orig
|
||||
cached_method_name: The name of the cached method.
|
||||
list_name: Name of the argument which is the bulk lookup list
|
||||
num_args: number of positional arguments (excluding ``self``,
|
||||
but including list_name) to use as cache keys. Defaults to all
|
||||
named args of the function.
|
||||
"""
|
||||
|
@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
% (self.list_name, cached_method_name)
|
||||
)
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
def __get__(
|
||||
self, obj: Optional[Any], objtype: Optional[Type] = None
|
||||
) -> Callable[..., Any]:
|
||||
cached_method = getattr(obj, self.cached_method_name)
|
||||
cache: DeferredCache[CacheKey, Any] = cached_method.cache
|
||||
num_args = cached_method.num_args
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args, **kwargs):
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
# If we're passed a cache_context then we'll want to call its
|
||||
# invalidate() whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
|
@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
|
||||
results = {}
|
||||
|
||||
def update_results_dict(res, arg):
|
||||
def update_results_dict(res: Any, arg: Hashable) -> None:
|
||||
results[arg] = res
|
||||
|
||||
# list of deferreds to wait for
|
||||
|
@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
# otherwise a tuple is used.
|
||||
if num_args == 1:
|
||||
|
||||
def arg_to_cache_key(arg):
|
||||
def arg_to_cache_key(arg: Hashable) -> Hashable:
|
||||
return arg
|
||||
|
||||
else:
|
||||
keylist = list(keyargs)
|
||||
|
||||
def arg_to_cache_key(arg):
|
||||
def arg_to_cache_key(arg: Hashable) -> Hashable:
|
||||
keylist[self.list_pos] = arg
|
||||
return tuple(keylist)
|
||||
|
||||
|
@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
key = arg_to_cache_key(arg)
|
||||
cache.set(key, deferred, callback=invalidate_callback)
|
||||
|
||||
def complete_all(res):
|
||||
def complete_all(res: Dict[Hashable, Any]) -> None:
|
||||
# the wrapped function has completed. It returns a
|
||||
# a dict. We can now resolve the observable deferreds in
|
||||
# the cache and update our own result map.
|
||||
|
@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|||
deferreds_map[e].callback(val)
|
||||
results[e] = val
|
||||
|
||||
def errback(f):
|
||||
def errback(f: Failure) -> Failure:
|
||||
# the wrapped function has failed. Invalidate any cache
|
||||
# entries we're supposed to be populating, and fail
|
||||
# their deferreds.
|
||||
|
|
|
@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
|
|||
import attr
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.config import cache as cache_config
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import Clock
|
||||
|
@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
|
|||
# Don't bother starting the loop if things never expire
|
||||
return
|
||||
|
||||
def f():
|
||||
def f() -> "defer.Deferred[None]":
|
||||
return run_as_background_process(
|
||||
"prune_cache_%s" % self._cache_name, self._prune_cache
|
||||
)
|
||||
|
@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
|
|||
return False
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class _CacheEntry:
|
||||
time = attr.ib(type=int)
|
||||
value = attr.ib()
|
||||
time: int
|
||||
value: Any
|
||||
|
|
|
@ -18,12 +18,13 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def user_left_room(distributor, user, room_id):
|
||||
def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
|
||||
distributor.fire("user_left_room", user=user, room_id=room_id)
|
||||
|
||||
|
||||
|
@ -63,7 +64,7 @@ class Distributor:
|
|||
self.pre_registration[name] = []
|
||||
self.pre_registration[name].append(observer)
|
||||
|
||||
def fire(self, name: str, *args, **kwargs) -> None:
|
||||
def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
|
||||
"""Dispatches the given signal to the registered observers.
|
||||
|
||||
Runs the observers as a background process. Does not return a deferred.
|
||||
|
@ -95,7 +96,7 @@ class Signal:
|
|||
Each observer callable may return a Deferred."""
|
||||
self.observers.append(observer)
|
||||
|
||||
def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
|
||||
def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
|
||||
"""Invokes every callable in the observer list, passing in the args and
|
||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||
not an error to fire a signal with no observers.
|
||||
|
@ -103,7 +104,7 @@ class Signal:
|
|||
Returns a Deferred that will complete when all the observers have
|
||||
completed."""
|
||||
|
||||
async def do(observer):
|
||||
async def do(observer: Callable[..., Any]) -> Any:
|
||||
try:
|
||||
return await maybe_awaitable(observer(*args, **kwargs))
|
||||
except Exception as e:
|
||||
|
@ -120,5 +121,5 @@ class Signal:
|
|||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "<Signal name=%r>" % (self.name,)
|
||||
|
|
|
@ -3,23 +3,52 @@
|
|||
# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
|
||||
# private class.
|
||||
|
||||
|
||||
from socket import (
|
||||
AF_INET,
|
||||
AF_INET6,
|
||||
AF_UNSPEC,
|
||||
SOCK_DGRAM,
|
||||
SOCK_STREAM,
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
gaierror,
|
||||
getaddrinfo,
|
||||
)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.interfaces import IHostnameResolver, IHostResolution
|
||||
from twisted.internet.interfaces import (
|
||||
IAddress,
|
||||
IHostnameResolver,
|
||||
IHostResolution,
|
||||
IReactorThreads,
|
||||
IResolutionReceiver,
|
||||
)
|
||||
from twisted.internet.threads import deferToThreadPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# The types below are copied from
|
||||
# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
|
||||
# so that the type hints can match the interfaces.
|
||||
from twisted.python.runtime import platform
|
||||
|
||||
if platform.supportsThreads():
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
else:
|
||||
ThreadPool = object # type: ignore[misc, assignment]
|
||||
|
||||
|
||||
@implementer(IHostResolution)
|
||||
class HostResolution:
|
||||
|
@ -27,13 +56,13 @@ class HostResolution:
|
|||
The in-progress resolution of a given hostname.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str):
|
||||
"""
|
||||
Create a L{HostResolution} with the given name.
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
def cancel(self):
|
||||
def cancel(self) -> NoReturn:
|
||||
# IHostResolution.cancel
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -62,6 +91,17 @@ _socktypeToType = {
|
|||
}
|
||||
|
||||
|
||||
_GETADDRINFO_RESULT = List[
|
||||
Tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
Union[Tuple[str, int], Tuple[str, int, int, int]],
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@implementer(IHostnameResolver)
|
||||
class GAIResolver:
|
||||
"""
|
||||
|
@ -69,7 +109,12 @@ class GAIResolver:
|
|||
L{getaddrinfo} in a thread.
|
||||
"""
|
||||
|
||||
def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
|
||||
def __init__(
|
||||
self,
|
||||
reactor: IReactorThreads,
|
||||
getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
|
||||
getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
|
||||
):
|
||||
"""
|
||||
Create a L{GAIResolver}.
|
||||
@param reactor: the reactor to schedule result-delivery on
|
||||
|
@ -89,14 +134,16 @@ class GAIResolver:
|
|||
)
|
||||
self._getaddrinfo = getaddrinfo
|
||||
|
||||
def resolveHostName(
|
||||
# The types on IHostnameResolver is incorrect in Twisted, see
|
||||
# https://twistedmatrix.com/trac/ticket/10276
|
||||
def resolveHostName( # type: ignore[override]
|
||||
self,
|
||||
resolutionReceiver,
|
||||
hostName,
|
||||
portNumber=0,
|
||||
addressTypes=None,
|
||||
transportSemantics="TCP",
|
||||
):
|
||||
resolutionReceiver: IResolutionReceiver,
|
||||
hostName: str,
|
||||
portNumber: int = 0,
|
||||
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
|
||||
transportSemantics: str = "TCP",
|
||||
) -> IHostResolution:
|
||||
"""
|
||||
See L{IHostnameResolver.resolveHostName}
|
||||
@param resolutionReceiver: see interface
|
||||
|
@ -112,7 +159,7 @@ class GAIResolver:
|
|||
]
|
||||
socketType = _transportToSocket[transportSemantics]
|
||||
|
||||
def get():
|
||||
def get() -> _GETADDRINFO_RESULT:
|
||||
try:
|
||||
return self._getaddrinfo(
|
||||
hostName, portNumber, addressFamily, socketType
|
||||
|
@ -125,7 +172,7 @@ class GAIResolver:
|
|||
resolutionReceiver.resolutionBegan(resolution)
|
||||
|
||||
@d.addCallback
|
||||
def deliverResults(result):
|
||||
def deliverResults(result: _GETADDRINFO_RESULT) -> None:
|
||||
for family, socktype, _proto, _cannoname, sockaddr in result:
|
||||
addrType = _afToType[family]
|
||||
resolutionReceiver.addressResolved(
|
||||
|
|
|
@ -64,6 +64,13 @@ in_flight = InFlightGauge(
|
|||
sub_metrics=["real_time_max", "real_time_sum"],
|
||||
)
|
||||
|
||||
|
||||
# This is dynamically created in InFlightGauge.__init__.
|
||||
class _InFlightMetric(Protocol):
|
||||
real_time_max: float
|
||||
real_time_sum: float
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable[..., Any])
|
||||
|
||||
|
||||
|
@ -180,7 +187,7 @@ class Measure:
|
|||
"""
|
||||
return self._logging_context.get_resource_usage()
|
||||
|
||||
def _update_in_flight(self, metrics) -> None:
|
||||
def _update_in_flight(self, metrics: _InFlightMetric) -> None:
|
||||
"""Gets called when processing in flight metrics"""
|
||||
assert self.start is not None
|
||||
duration = self.clock.time() - self.start
|
||||
|
|
Loading…
Reference in a new issue