forked from MirrorHub/synapse
Add most missing type hints to synapse.util (#11328)
This commit is contained in:
parent
3a1462f7e0
commit
7468723697
10 changed files with 161 additions and 165 deletions
1
changelog.d/11328.misc
Normal file
1
changelog.d/11328.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to `synapse.util`.
|
87
mypy.ini
87
mypy.ini
|
@ -196,92 +196,11 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.streams.*]
|
[mypy-synapse.streams.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.util.batching_queue]
|
[mypy-synapse.util.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.util.caches.cached_call]
|
[mypy-synapse.util.caches.treecache]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
[mypy-synapse.util.caches.dictionary_cache]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.caches.lrucache]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.caches.response_cache]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.caches.stream_change_cache]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.caches.ttl_cache]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.daemonize]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.file_consumer]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.frozenutils]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.hash]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.httpresourcetree]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.iterutils]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.linked_list]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.logcontext]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.logformatter]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.macaroons]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.manhole]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.module_loader]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.msisdn]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.patch_inline_callbacks]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.ratelimitutils]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.retryutils]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.rlimit]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.stringutils]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.templates]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.threepids]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.wheel_timer]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-synapse.util.versionstring]
|
|
||||||
disallow_untyped_defs = True
|
|
||||||
|
|
||||||
[mypy-tests.handlers.test_user_directory]
|
[mypy-tests.handlers.test_user_directory]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
|
@ -27,6 +27,7 @@ from typing import (
|
||||||
Generic,
|
Generic,
|
||||||
Hashable,
|
Hashable,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
Iterator,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
@ -40,7 +41,6 @@ from typing_extensions import ContextManager
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
from twisted.internet.interfaces import IReactorTime
|
from twisted.internet.interfaces import IReactorTime
|
||||||
from twisted.python import failure
|
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
|
@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
|
||||||
object.__setattr__(self, "_result", None)
|
object.__setattr__(self, "_result", None)
|
||||||
object.__setattr__(self, "_observers", [])
|
object.__setattr__(self, "_observers", [])
|
||||||
|
|
||||||
def callback(r):
|
def callback(r: _T) -> _T:
|
||||||
object.__setattr__(self, "_result", (True, r))
|
object.__setattr__(self, "_result", (True, r))
|
||||||
|
|
||||||
# once we have set _result, no more entries will be added to _observers,
|
# once we have set _result, no more entries will be added to _observers,
|
||||||
|
@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
|
||||||
)
|
)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def errback(f):
|
def errback(f: Failure) -> Optional[Failure]:
|
||||||
object.__setattr__(self, "_result", (False, f))
|
object.__setattr__(self, "_result", (False, f))
|
||||||
|
|
||||||
# once we have set _result, no more entries will be added to _observers,
|
# once we have set _result, no more entries will be added to _observers,
|
||||||
|
@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
|
||||||
for observer in observers:
|
for observer in observers:
|
||||||
# This is a little bit of magic to correctly propagate stack
|
# This is a little bit of magic to correctly propagate stack
|
||||||
# traces when we `await` on one of the observer deferreds.
|
# traces when we `await` on one of the observer deferreds.
|
||||||
f.value.__failure__ = f
|
f.value.__failure__ = f # type: ignore[union-attr]
|
||||||
try:
|
try:
|
||||||
observer.errback(f)
|
observer.errback(f)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -314,7 +314,7 @@ class Linearizer:
|
||||||
# will release the lock.
|
# will release the lock.
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager(_):
|
def _ctx_manager(_: None) -> Iterator[None]:
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
@ -355,7 +355,7 @@ class Linearizer:
|
||||||
new_defer = make_deferred_yieldable(defer.Deferred())
|
new_defer = make_deferred_yieldable(defer.Deferred())
|
||||||
entry.deferreds[new_defer] = 1
|
entry.deferreds[new_defer] = 1
|
||||||
|
|
||||||
def cb(_r):
|
def cb(_r: None) -> "defer.Deferred[None]":
|
||||||
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
|
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
|
||||||
entry.count += 1
|
entry.count += 1
|
||||||
|
|
||||||
|
@ -371,7 +371,7 @@ class Linearizer:
|
||||||
# code must be synchronous, so this is the only sensible place.)
|
# code must be synchronous, so this is the only sensible place.)
|
||||||
return self._clock.sleep(0)
|
return self._clock.sleep(0)
|
||||||
|
|
||||||
def eb(e):
|
def eb(e: Failure) -> Failure:
|
||||||
logger.info("defer %r got err %r", new_defer, e)
|
logger.info("defer %r got err %r", new_defer, e)
|
||||||
if isinstance(e, CancelledError):
|
if isinstance(e, CancelledError):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -435,7 +435,7 @@ class ReadWriteLock:
|
||||||
await make_deferred_yieldable(curr_writer)
|
await make_deferred_yieldable(curr_writer)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager() -> Iterator[None]:
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
@ -464,7 +464,7 @@ class ReadWriteLock:
|
||||||
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager() -> Iterator[None]:
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
@ -524,7 +524,7 @@ def timeout_deferred(
|
||||||
|
|
||||||
delayed_call = reactor.callLater(timeout, time_it_out)
|
delayed_call = reactor.callLater(timeout, time_it_out)
|
||||||
|
|
||||||
def convert_cancelled(value: failure.Failure):
|
def convert_cancelled(value: Failure) -> Failure:
|
||||||
# if the original deferred was cancelled, and our timeout has fired, then
|
# if the original deferred was cancelled, and our timeout has fired, then
|
||||||
# the reason it was cancelled was due to our timeout. Turn the CancelledError
|
# the reason it was cancelled was due to our timeout. Turn the CancelledError
|
||||||
# into a TimeoutError.
|
# into a TimeoutError.
|
||||||
|
@ -534,7 +534,7 @@ def timeout_deferred(
|
||||||
|
|
||||||
deferred.addErrback(convert_cancelled)
|
deferred.addErrback(convert_cancelled)
|
||||||
|
|
||||||
def cancel_timeout(result):
|
def cancel_timeout(result: _T) -> _T:
|
||||||
# stop the pending call to cancel the deferred if it's been fired
|
# stop the pending call to cancel the deferred if it's been fired
|
||||||
if delayed_call.active():
|
if delayed_call.active():
|
||||||
delayed_call.cancel()
|
delayed_call.cancel()
|
||||||
|
@ -542,11 +542,11 @@ def timeout_deferred(
|
||||||
|
|
||||||
deferred.addBoth(cancel_timeout)
|
deferred.addBoth(cancel_timeout)
|
||||||
|
|
||||||
def success_cb(val):
|
def success_cb(val: _T) -> None:
|
||||||
if not new_d.called:
|
if not new_d.called:
|
||||||
new_d.callback(val)
|
new_d.callback(val)
|
||||||
|
|
||||||
def failure_cb(val):
|
def failure_cb(val: Failure) -> None:
|
||||||
if not new_d.called:
|
if not new_d.called:
|
||||||
new_d.errback(val)
|
new_d.errback(val)
|
||||||
|
|
||||||
|
@ -557,13 +557,13 @@ def timeout_deferred(
|
||||||
|
|
||||||
# This class can't be generic because it uses slots with attrs.
|
# This class can't be generic because it uses slots with attrs.
|
||||||
# See: https://github.com/python-attrs/attrs/issues/313
|
# See: https://github.com/python-attrs/attrs/issues/313
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
class DoneAwaitable: # should be: Generic[R]
|
class DoneAwaitable: # should be: Generic[R]
|
||||||
"""Simple awaitable that returns the provided value."""
|
"""Simple awaitable that returns the provided value."""
|
||||||
|
|
||||||
value = attr.ib(type=Any) # should be: R
|
value: Any # should be: R
|
||||||
|
|
||||||
def __await__(self):
|
def __await__(self) -> Any:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __iter__(self) -> "DoneAwaitable":
|
def __iter__(self) -> "DoneAwaitable":
|
||||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
||||||
import typing
|
import typing
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from sys import intern
|
from sys import intern
|
||||||
from typing import Callable, Dict, Optional, Sized
|
from typing import Any, Callable, Dict, List, Optional, Sized
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client.core import Gauge
|
from prometheus_client.core import Gauge
|
||||||
|
@ -58,20 +58,20 @@ class EvictionReason(Enum):
|
||||||
time = auto()
|
time = auto()
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
class CacheMetric:
|
class CacheMetric:
|
||||||
|
|
||||||
_cache = attr.ib()
|
_cache: Sized
|
||||||
_cache_type = attr.ib(type=str)
|
_cache_type: str
|
||||||
_cache_name = attr.ib(type=str)
|
_cache_name: str
|
||||||
_collect_callback = attr.ib(type=Optional[Callable])
|
_collect_callback: Optional[Callable]
|
||||||
|
|
||||||
hits = attr.ib(default=0)
|
hits: int = 0
|
||||||
misses = attr.ib(default=0)
|
misses: int = 0
|
||||||
eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
|
eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
|
||||||
factory=collections.Counter
|
factory=collections.Counter
|
||||||
)
|
)
|
||||||
memory_usage = attr.ib(default=None)
|
memory_usage: Optional[int] = None
|
||||||
|
|
||||||
def inc_hits(self) -> None:
|
def inc_hits(self) -> None:
|
||||||
self.hits += 1
|
self.hits += 1
|
||||||
|
@ -89,13 +89,14 @@ class CacheMetric:
|
||||||
self.memory_usage += memory
|
self.memory_usage += memory
|
||||||
|
|
||||||
def dec_memory_usage(self, memory: int) -> None:
|
def dec_memory_usage(self, memory: int) -> None:
|
||||||
|
assert self.memory_usage is not None
|
||||||
self.memory_usage -= memory
|
self.memory_usage -= memory
|
||||||
|
|
||||||
def clear_memory_usage(self) -> None:
|
def clear_memory_usage(self) -> None:
|
||||||
if self.memory_usage is not None:
|
if self.memory_usage is not None:
|
||||||
self.memory_usage = 0
|
self.memory_usage = 0
|
||||||
|
|
||||||
def describe(self):
|
def describe(self) -> List[str]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def collect(self) -> None:
|
def collect(self) -> None:
|
||||||
|
@ -118,8 +119,9 @@ class CacheMetric:
|
||||||
self.eviction_size_by_reason[reason]
|
self.eviction_size_by_reason[reason]
|
||||||
)
|
)
|
||||||
cache_total.labels(self._cache_name).set(self.hits + self.misses)
|
cache_total.labels(self._cache_name).set(self.hits + self.misses)
|
||||||
if getattr(self._cache, "max_size", None):
|
max_size = getattr(self._cache, "max_size", None)
|
||||||
cache_max_size.labels(self._cache_name).set(self._cache.max_size)
|
if max_size:
|
||||||
|
cache_max_size.labels(self._cache_name).set(max_size)
|
||||||
|
|
||||||
if TRACK_MEMORY_USAGE:
|
if TRACK_MEMORY_USAGE:
|
||||||
# self.memory_usage can be None if nothing has been inserted
|
# self.memory_usage can be None if nothing has been inserted
|
||||||
|
@ -193,7 +195,7 @@ KNOWN_KEYS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def intern_string(string):
|
def intern_string(string: Optional[str]) -> Optional[str]:
|
||||||
"""Takes a (potentially) unicode string and interns it if it's ascii"""
|
"""Takes a (potentially) unicode string and interns it if it's ascii"""
|
||||||
if string is None:
|
if string is None:
|
||||||
return None
|
return None
|
||||||
|
@ -204,7 +206,7 @@ def intern_string(string):
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def intern_dict(dictionary):
|
def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Takes a dictionary and interns well known keys and their values"""
|
"""Takes a dictionary and interns well known keys and their values"""
|
||||||
return {
|
return {
|
||||||
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
|
KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
|
||||||
|
@ -212,7 +214,7 @@ def intern_dict(dictionary):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _intern_known_values(key, value):
|
def _intern_known_values(key: str, value: Any) -> Any:
|
||||||
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
|
intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
|
||||||
|
|
||||||
if key in intern_keys:
|
if key in intern_keys:
|
||||||
|
|
|
@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
self.cache.set(key, value, callbacks=callbacks)
|
self.cache.set(key, value, callbacks=callbacks)
|
||||||
|
|
||||||
def invalidate(self, key) -> None:
|
def invalidate(self, key: KT) -> None:
|
||||||
"""Delete a key, or tree of entries
|
"""Delete a key, or tree of entries
|
||||||
|
|
||||||
If the cache is backed by a regular dict, then "key" must be of
|
If the cache is backed by a regular dict, then "key" must be of
|
||||||
|
|
|
@ -19,12 +19,15 @@ import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
|
Hashable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
|
@ -32,6 +35,7 @@ from typing import (
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
|
||||||
|
|
||||||
|
|
||||||
class _CacheDescriptorBase:
|
class _CacheDescriptorBase:
|
||||||
def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
orig: Callable[..., Any],
|
||||||
|
num_args: Optional[int],
|
||||||
|
cache_context: bool = False,
|
||||||
|
):
|
||||||
self.orig = orig
|
self.orig = orig
|
||||||
|
|
||||||
arg_spec = inspect.getfullargspec(orig)
|
arg_spec = inspect.getfullargspec(orig)
|
||||||
|
@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
orig,
|
orig: Callable[..., Any],
|
||||||
max_entries: int = 1000,
|
max_entries: int = 1000,
|
||||||
cache_context: bool = False,
|
cache_context: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(orig, num_args=None, cache_context=cache_context)
|
super().__init__(orig, num_args=None, cache_context=cache_context)
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
|
|
||||||
def __get__(self, obj, owner):
|
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
||||||
cache: LruCache[CacheKey, Any] = LruCache(
|
cache: LruCache[CacheKey, Any] = LruCache(
|
||||||
cache_name=self.orig.__name__,
|
cache_name=self.orig.__name__,
|
||||||
max_size=self.max_entries,
|
max_size=self.max_entries,
|
||||||
|
@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
||||||
sentinel = LruCacheDescriptor._Sentinel.sentinel
|
sentinel = LruCacheDescriptor._Sentinel.sentinel
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def _wrapped(*args, **kwargs):
|
def _wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
callbacks = (invalidate_callback,) if invalidate_callback else ()
|
callbacks = (invalidate_callback,) if invalidate_callback else ()
|
||||||
|
|
||||||
|
@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
return r1 + r2
|
return r1 + r2
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_args (int): number of positional arguments (excluding ``self`` and
|
num_args: number of positional arguments (excluding ``self`` and
|
||||||
``cache_context``) to use as cache keys. Defaults to all named
|
``cache_context``) to use as cache keys. Defaults to all named
|
||||||
args of the function.
|
args of the function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
orig,
|
orig: Callable[..., Any],
|
||||||
max_entries=1000,
|
max_entries: int = 1000,
|
||||||
num_args=None,
|
num_args: Optional[int] = None,
|
||||||
tree=False,
|
tree: bool = False,
|
||||||
cache_context=False,
|
cache_context: bool = False,
|
||||||
iterable=False,
|
iterable: bool = False,
|
||||||
prune_unread_entries: bool = True,
|
prune_unread_entries: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
||||||
|
@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
self.iterable = iterable
|
self.iterable = iterable
|
||||||
self.prune_unread_entries = prune_unread_entries
|
self.prune_unread_entries = prune_unread_entries
|
||||||
|
|
||||||
def __get__(self, obj, owner):
|
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
||||||
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
||||||
name=self.orig.__name__,
|
name=self.orig.__name__,
|
||||||
max_entries=self.max_entries,
|
max_entries=self.max_entries,
|
||||||
|
@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||||
get_cache_key = self.cache_key_builder
|
get_cache_key = self.cache_key_builder
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def _wrapped(*args, **kwargs):
|
def _wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||||
# whenever we are invalidated
|
# whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
of results.
|
of results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, orig, cached_method_name, list_name, num_args=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
orig: Callable[..., Any],
|
||||||
|
cached_method_name: str,
|
||||||
|
list_name: str,
|
||||||
|
num_args: Optional[int] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
orig (function)
|
orig
|
||||||
cached_method_name (str): The name of the cached method.
|
cached_method_name: The name of the cached method.
|
||||||
list_name (str): Name of the argument which is the bulk lookup list
|
list_name: Name of the argument which is the bulk lookup list
|
||||||
num_args (int): number of positional arguments (excluding ``self``,
|
num_args: number of positional arguments (excluding ``self``,
|
||||||
but including list_name) to use as cache keys. Defaults to all
|
but including list_name) to use as cache keys. Defaults to all
|
||||||
named args of the function.
|
named args of the function.
|
||||||
"""
|
"""
|
||||||
|
@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
% (self.list_name, cached_method_name)
|
% (self.list_name, cached_method_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
def __get__(
|
||||||
|
self, obj: Optional[Any], objtype: Optional[Type] = None
|
||||||
|
) -> Callable[..., Any]:
|
||||||
cached_method = getattr(obj, self.cached_method_name)
|
cached_method = getattr(obj, self.cached_method_name)
|
||||||
cache: DeferredCache[CacheKey, Any] = cached_method.cache
|
cache: DeferredCache[CacheKey, Any] = cached_method.cache
|
||||||
num_args = cached_method.num_args
|
num_args = cached_method.num_args
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||||
# If we're passed a cache_context then we'll want to call its
|
# If we're passed a cache_context then we'll want to call its
|
||||||
# invalidate() whenever we are invalidated
|
# invalidate() whenever we are invalidated
|
||||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||||
|
@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
def update_results_dict(res, arg):
|
def update_results_dict(res: Any, arg: Hashable) -> None:
|
||||||
results[arg] = res
|
results[arg] = res
|
||||||
|
|
||||||
# list of deferreds to wait for
|
# list of deferreds to wait for
|
||||||
|
@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
# otherwise a tuple is used.
|
# otherwise a tuple is used.
|
||||||
if num_args == 1:
|
if num_args == 1:
|
||||||
|
|
||||||
def arg_to_cache_key(arg):
|
def arg_to_cache_key(arg: Hashable) -> Hashable:
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
else:
|
else:
|
||||||
keylist = list(keyargs)
|
keylist = list(keyargs)
|
||||||
|
|
||||||
def arg_to_cache_key(arg):
|
def arg_to_cache_key(arg: Hashable) -> Hashable:
|
||||||
keylist[self.list_pos] = arg
|
keylist[self.list_pos] = arg
|
||||||
return tuple(keylist)
|
return tuple(keylist)
|
||||||
|
|
||||||
|
@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
key = arg_to_cache_key(arg)
|
key = arg_to_cache_key(arg)
|
||||||
cache.set(key, deferred, callback=invalidate_callback)
|
cache.set(key, deferred, callback=invalidate_callback)
|
||||||
|
|
||||||
def complete_all(res):
|
def complete_all(res: Dict[Hashable, Any]) -> None:
|
||||||
# the wrapped function has completed. It returns a
|
# the wrapped function has completed. It returns a
|
||||||
# a dict. We can now resolve the observable deferreds in
|
# a dict. We can now resolve the observable deferreds in
|
||||||
# the cache and update our own result map.
|
# the cache and update our own result map.
|
||||||
|
@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||||
deferreds_map[e].callback(val)
|
deferreds_map[e].callback(val)
|
||||||
results[e] = val
|
results[e] = val
|
||||||
|
|
||||||
def errback(f):
|
def errback(f: Failure) -> Failure:
|
||||||
# the wrapped function has failed. Invalidate any cache
|
# the wrapped function has failed. Invalidate any cache
|
||||||
# entries we're supposed to be populating, and fail
|
# entries we're supposed to be populating, and fail
|
||||||
# their deferreds.
|
# their deferreds.
|
||||||
|
|
|
@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.config import cache as cache_config
|
from synapse.config import cache as cache_config
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
|
||||||
# Don't bother starting the loop if things never expire
|
# Don't bother starting the loop if things never expire
|
||||||
return
|
return
|
||||||
|
|
||||||
def f():
|
def f() -> "defer.Deferred[None]":
|
||||||
return run_as_background_process(
|
return run_as_background_process(
|
||||||
"prune_cache_%s" % self._cache_name, self._prune_cache
|
"prune_cache_%s" % self._cache_name, self._prune_cache
|
||||||
)
|
)
|
||||||
|
@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
class _CacheEntry:
|
class _CacheEntry:
|
||||||
time = attr.ib(type=int)
|
time: int
|
||||||
value = attr.ib()
|
value: Any
|
||||||
|
|
|
@ -18,12 +18,13 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.types import UserID
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def user_left_room(distributor, user, room_id):
|
def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
|
||||||
distributor.fire("user_left_room", user=user, room_id=room_id)
|
distributor.fire("user_left_room", user=user, room_id=room_id)
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +64,7 @@ class Distributor:
|
||||||
self.pre_registration[name] = []
|
self.pre_registration[name] = []
|
||||||
self.pre_registration[name].append(observer)
|
self.pre_registration[name].append(observer)
|
||||||
|
|
||||||
def fire(self, name: str, *args, **kwargs) -> None:
|
def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
|
||||||
"""Dispatches the given signal to the registered observers.
|
"""Dispatches the given signal to the registered observers.
|
||||||
|
|
||||||
Runs the observers as a background process. Does not return a deferred.
|
Runs the observers as a background process. Does not return a deferred.
|
||||||
|
@ -95,7 +96,7 @@ class Signal:
|
||||||
Each observer callable may return a Deferred."""
|
Each observer callable may return a Deferred."""
|
||||||
self.observers.append(observer)
|
self.observers.append(observer)
|
||||||
|
|
||||||
def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
|
def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
|
||||||
"""Invokes every callable in the observer list, passing in the args and
|
"""Invokes every callable in the observer list, passing in the args and
|
||||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||||
not an error to fire a signal with no observers.
|
not an error to fire a signal with no observers.
|
||||||
|
@ -103,7 +104,7 @@ class Signal:
|
||||||
Returns a Deferred that will complete when all the observers have
|
Returns a Deferred that will complete when all the observers have
|
||||||
completed."""
|
completed."""
|
||||||
|
|
||||||
async def do(observer):
|
async def do(observer: Callable[..., Any]) -> Any:
|
||||||
try:
|
try:
|
||||||
return await maybe_awaitable(observer(*args, **kwargs))
|
return await maybe_awaitable(observer(*args, **kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -120,5 +121,5 @@ class Signal:
|
||||||
defer.gatherResults(deferreds, consumeErrors=True)
|
defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "<Signal name=%r>" % (self.name,)
|
return "<Signal name=%r>" % (self.name,)
|
||||||
|
|
|
@ -3,23 +3,52 @@
|
||||||
# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
|
# We copy it here as we need to instantiate `GAIResolver` manually, but it is a
|
||||||
# private class.
|
# private class.
|
||||||
|
|
||||||
|
|
||||||
from socket import (
|
from socket import (
|
||||||
AF_INET,
|
AF_INET,
|
||||||
AF_INET6,
|
AF_INET6,
|
||||||
AF_UNSPEC,
|
AF_UNSPEC,
|
||||||
SOCK_DGRAM,
|
SOCK_DGRAM,
|
||||||
SOCK_STREAM,
|
SOCK_STREAM,
|
||||||
|
AddressFamily,
|
||||||
|
SocketKind,
|
||||||
gaierror,
|
gaierror,
|
||||||
getaddrinfo,
|
getaddrinfo,
|
||||||
)
|
)
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
NoReturn,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet.address import IPv4Address, IPv6Address
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
from twisted.internet.interfaces import IHostnameResolver, IHostResolution
|
from twisted.internet.interfaces import (
|
||||||
|
IAddress,
|
||||||
|
IHostnameResolver,
|
||||||
|
IHostResolution,
|
||||||
|
IReactorThreads,
|
||||||
|
IResolutionReceiver,
|
||||||
|
)
|
||||||
from twisted.internet.threads import deferToThreadPool
|
from twisted.internet.threads import deferToThreadPool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# The types below are copied from
|
||||||
|
# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
|
||||||
|
# so that the type hints can match the interfaces.
|
||||||
|
from twisted.python.runtime import platform
|
||||||
|
|
||||||
|
if platform.supportsThreads():
|
||||||
|
from twisted.python.threadpool import ThreadPool
|
||||||
|
else:
|
||||||
|
ThreadPool = object # type: ignore[misc, assignment]
|
||||||
|
|
||||||
|
|
||||||
@implementer(IHostResolution)
|
@implementer(IHostResolution)
|
||||||
class HostResolution:
|
class HostResolution:
|
||||||
|
@ -27,13 +56,13 @@ class HostResolution:
|
||||||
The in-progress resolution of a given hostname.
|
The in-progress resolution of a given hostname.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name: str):
|
||||||
"""
|
"""
|
||||||
Create a L{HostResolution} with the given name.
|
Create a L{HostResolution} with the given name.
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self) -> NoReturn:
|
||||||
# IHostResolution.cancel
|
# IHostResolution.cancel
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -62,6 +91,17 @@ _socktypeToType = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_GETADDRINFO_RESULT = List[
|
||||||
|
Tuple[
|
||||||
|
AddressFamily,
|
||||||
|
SocketKind,
|
||||||
|
int,
|
||||||
|
str,
|
||||||
|
Union[Tuple[str, int], Tuple[str, int, int, int]],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@implementer(IHostnameResolver)
|
@implementer(IHostnameResolver)
|
||||||
class GAIResolver:
|
class GAIResolver:
|
||||||
"""
|
"""
|
||||||
|
@ -69,7 +109,12 @@ class GAIResolver:
|
||||||
L{getaddrinfo} in a thread.
|
L{getaddrinfo} in a thread.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
|
def __init__(
|
||||||
|
self,
|
||||||
|
reactor: IReactorThreads,
|
||||||
|
getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
|
||||||
|
getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create a L{GAIResolver}.
|
Create a L{GAIResolver}.
|
||||||
@param reactor: the reactor to schedule result-delivery on
|
@param reactor: the reactor to schedule result-delivery on
|
||||||
|
@ -89,14 +134,16 @@ class GAIResolver:
|
||||||
)
|
)
|
||||||
self._getaddrinfo = getaddrinfo
|
self._getaddrinfo = getaddrinfo
|
||||||
|
|
||||||
def resolveHostName(
|
# The types on IHostnameResolver is incorrect in Twisted, see
|
||||||
|
# https://twistedmatrix.com/trac/ticket/10276
|
||||||
|
def resolveHostName( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
resolutionReceiver,
|
resolutionReceiver: IResolutionReceiver,
|
||||||
hostName,
|
hostName: str,
|
||||||
portNumber=0,
|
portNumber: int = 0,
|
||||||
addressTypes=None,
|
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
|
||||||
transportSemantics="TCP",
|
transportSemantics: str = "TCP",
|
||||||
):
|
) -> IHostResolution:
|
||||||
"""
|
"""
|
||||||
See L{IHostnameResolver.resolveHostName}
|
See L{IHostnameResolver.resolveHostName}
|
||||||
@param resolutionReceiver: see interface
|
@param resolutionReceiver: see interface
|
||||||
|
@ -112,7 +159,7 @@ class GAIResolver:
|
||||||
]
|
]
|
||||||
socketType = _transportToSocket[transportSemantics]
|
socketType = _transportToSocket[transportSemantics]
|
||||||
|
|
||||||
def get():
|
def get() -> _GETADDRINFO_RESULT:
|
||||||
try:
|
try:
|
||||||
return self._getaddrinfo(
|
return self._getaddrinfo(
|
||||||
hostName, portNumber, addressFamily, socketType
|
hostName, portNumber, addressFamily, socketType
|
||||||
|
@ -125,7 +172,7 @@ class GAIResolver:
|
||||||
resolutionReceiver.resolutionBegan(resolution)
|
resolutionReceiver.resolutionBegan(resolution)
|
||||||
|
|
||||||
@d.addCallback
|
@d.addCallback
|
||||||
def deliverResults(result):
|
def deliverResults(result: _GETADDRINFO_RESULT) -> None:
|
||||||
for family, socktype, _proto, _cannoname, sockaddr in result:
|
for family, socktype, _proto, _cannoname, sockaddr in result:
|
||||||
addrType = _afToType[family]
|
addrType = _afToType[family]
|
||||||
resolutionReceiver.addressResolved(
|
resolutionReceiver.addressResolved(
|
||||||
|
|
|
@ -64,6 +64,13 @@ in_flight = InFlightGauge(
|
||||||
sub_metrics=["real_time_max", "real_time_sum"],
|
sub_metrics=["real_time_max", "real_time_sum"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# This is dynamically created in InFlightGauge.__init__.
|
||||||
|
class _InFlightMetric(Protocol):
|
||||||
|
real_time_max: float
|
||||||
|
real_time_sum: float
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=Callable[..., Any])
|
T = TypeVar("T", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
@ -180,7 +187,7 @@ class Measure:
|
||||||
"""
|
"""
|
||||||
return self._logging_context.get_resource_usage()
|
return self._logging_context.get_resource_usage()
|
||||||
|
|
||||||
def _update_in_flight(self, metrics) -> None:
|
def _update_in_flight(self, metrics: _InFlightMetric) -> None:
|
||||||
"""Gets called when processing in flight metrics"""
|
"""Gets called when processing in flight metrics"""
|
||||||
assert self.start is not None
|
assert self.start is not None
|
||||||
duration = self.clock.time() - self.start
|
duration = self.clock.time() - self.start
|
||||||
|
|
Loading…
Reference in a new issue