0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 14:32:30 +01:00

Add type hints to test_descriptors. (#15659)

Require type hints in test_descriptors and add missing ones.
This commit is contained in:
Patrick Cloke 2023-05-24 10:18:52 -04:00 committed by GitHub
parent c7e9c1d5ae
commit ca5c4be921
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 96 deletions

1
changelog.d/15659.misc Normal file
View file

@ -0,0 +1 @@
Improve type hints.

View file

@ -32,9 +32,6 @@ warn_unused_ignores = False
[mypy-synapse.util.caches.treecache] [mypy-synapse.util.caches.treecache]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False
;; Dependencies without annotations ;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available. ;; Before ignoring a module, check to see if type stubs are available.
;; The `typeshed` project maintains stubs here: ;; The `typeshed` project maintains stubs here:

View file

@ -13,7 +13,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, Set, Tuple, cast from typing import (
Any,
Dict,
Generator,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
cast,
)
from unittest import mock from unittest import mock
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -29,7 +40,7 @@ from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
) )
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from tests import unittest from tests import unittest
from tests.test_utils import get_awaitable_result from tests.test_utils import get_awaitable_result
@ -37,21 +48,21 @@ from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_on_reactor(): def run_on_reactor() -> "Deferred[int]":
d: "Deferred[int]" = defer.Deferred() d: "Deferred[int]" = Deferred()
cast(IReactorTime, reactor).callLater(0, d.callback, 0) cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
class DescriptorTestCase(unittest.TestCase): class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache(self): def test_cache(self) -> Generator["Deferred[Any]", object, None]:
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached() @descriptors.cached()
def fn(self, arg1, arg2): def fn(self, arg1: int, arg2: int) -> str:
return self.mock(arg1, arg2) return self.mock(arg1, arg2)
obj = Cls() obj = Cls()
@ -77,15 +88,15 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called() obj.mock.assert_not_called()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache_num_args(self): def test_cache_num_args(self) -> Generator["Deferred[Any]", object, None]:
"""Only the first num_args arguments should matter to the cache""" """Only the first num_args arguments should matter to the cache"""
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached(num_args=1) @descriptors.cached(num_args=1)
def fn(self, arg1, arg2): def fn(self, arg1: int, arg2: int) -> mock.Mock:
return self.mock(arg1, arg2) return self.mock(arg1, arg2)
obj = Cls() obj = Cls()
@ -111,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called() obj.mock.assert_not_called()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache_uncached_args(self): def test_cache_uncached_args(self) -> Generator["Deferred[Any]", object, None]:
""" """
Only the arguments not named in uncached_args should matter to the cache Only the arguments not named in uncached_args should matter to the cache
@ -123,10 +134,10 @@ class DescriptorTestCase(unittest.TestCase):
# Note that it is important that this is not the last argument to # Note that it is important that this is not the last argument to
# test behaviour of skipping arguments properly. # test behaviour of skipping arguments properly.
@descriptors.cached(uncached_args=("arg2",)) @descriptors.cached(uncached_args=("arg2",))
def fn(self, arg1, arg2, arg3): def fn(self, arg1: int, arg2: int, arg3: int) -> str:
return self.mock(arg1, arg2, arg3) return self.mock(arg1, arg2, arg3)
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
obj = Cls() obj = Cls()
@ -152,15 +163,15 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called() obj.mock.assert_not_called()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache_kwargs(self): def test_cache_kwargs(self) -> Generator["Deferred[Any]", object, None]:
"""Test that keyword arguments are treated properly""" """Test that keyword arguments are treated properly"""
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached() @descriptors.cached()
def fn(self, arg1, kwarg1=2): def fn(self, arg1: int, kwarg1: int = 2) -> str:
return self.mock(arg1, kwarg1=kwarg1) return self.mock(arg1, kwarg1=kwarg1)
obj = Cls() obj = Cls()
@ -188,12 +199,12 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "fish") self.assertEqual(r, "fish")
obj.mock.assert_not_called() obj.mock.assert_not_called()
def test_cache_with_sync_exception(self): def test_cache_with_sync_exception(self) -> None:
"""If the wrapped function throws synchronously, things should continue to work""" """If the wrapped function throws synchronously, things should continue to work"""
class Cls: class Cls:
@cached() @cached()
def fn(self, arg1): def fn(self, arg1: int) -> NoReturn:
raise SynapseError(100, "mai spoon iz too big!!1") raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls() obj = Cls()
@ -209,15 +220,15 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1) d = obj.fn(1)
self.failureResultOf(d, SynapseError) self.failureResultOf(d, SynapseError)
def test_cache_with_async_exception(self): def test_cache_with_async_exception(self) -> None:
"""The wrapped function returns a failure""" """The wrapped function returns a failure"""
class Cls: class Cls:
result = None result: Optional[Deferred] = None
call_count = 0 call_count = 0
@cached() @cached()
def fn(self, arg1): def fn(self, arg1: int) -> Optional[Deferred]:
self.call_count += 1 self.call_count += 1
return self.result return self.result
@ -225,7 +236,7 @@ class DescriptorTestCase(unittest.TestCase):
callbacks: Set[str] = set() callbacks: Set[str] = set()
# set off an asynchronous request # set off an asynchronous request
origin_d: Deferred = defer.Deferred() origin_d: Deferred = Deferred()
obj.result = origin_d obj.result = origin_d
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
@ -260,17 +271,17 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(self.successResultOf(d3), 100) self.assertEqual(self.successResultOf(d3), 100)
self.assertEqual(obj.call_count, 2) self.assertEqual(obj.call_count, 2)
def test_cache_logcontexts(self): def test_cache_logcontexts(self) -> Deferred:
"""Check that logcontexts are set and restored correctly when """Check that logcontexts are set and restored correctly when
using the cache.""" using the cache."""
complete_lookup: Deferred = defer.Deferred() complete_lookup: Deferred = Deferred()
class Cls: class Cls:
@descriptors.cached() @descriptors.cached()
def fn(self, arg1): def fn(self, arg1: int) -> "Deferred[int]":
@defer.inlineCallbacks @defer.inlineCallbacks
def inner_fn(): def inner_fn() -> Generator["Deferred[object]", object, int]:
with PreserveLoggingContext(): with PreserveLoggingContext():
yield complete_lookup yield complete_lookup
return 1 return 1
@ -278,13 +289,13 @@ class DescriptorTestCase(unittest.TestCase):
return inner_fn() return inner_fn()
@defer.inlineCallbacks @defer.inlineCallbacks
def do_lookup(): def do_lookup() -> Generator["Deferred[Any]", object, int]:
with LoggingContext("c1") as c1: with LoggingContext("c1") as c1:
r = yield obj.fn(1) r = yield obj.fn(1)
self.assertEqual(current_context(), c1) self.assertEqual(current_context(), c1)
return r return cast(int, r)
def check_result(r): def check_result(r: int) -> None:
self.assertEqual(r, 1) self.assertEqual(r, 1)
obj = Cls() obj = Cls()
@ -304,15 +315,15 @@ class DescriptorTestCase(unittest.TestCase):
return defer.gatherResults([d1, d2]) return defer.gatherResults([d1, d2])
def test_cache_logcontexts_with_exception(self): def test_cache_logcontexts_with_exception(self) -> "Deferred[None]":
"""Check that the cache sets and restores logcontexts correctly when """Check that the cache sets and restores logcontexts correctly when
the lookup function throws an exception""" the lookup function throws an exception"""
class Cls: class Cls:
@descriptors.cached() @descriptors.cached()
def fn(self, arg1): def fn(self, arg1: int) -> Deferred:
@defer.inlineCallbacks @defer.inlineCallbacks
def inner_fn(): def inner_fn() -> Generator["Deferred[Any]", object, NoReturn]:
# we want this to behave like an asynchronous function # we want this to behave like an asynchronous function
yield run_on_reactor() yield run_on_reactor()
raise SynapseError(400, "blah") raise SynapseError(400, "blah")
@ -320,7 +331,7 @@ class DescriptorTestCase(unittest.TestCase):
return inner_fn() return inner_fn()
@defer.inlineCallbacks @defer.inlineCallbacks
def do_lookup(): def do_lookup() -> Generator["Deferred[object]", object, None]:
with LoggingContext("c1") as c1: with LoggingContext("c1") as c1:
try: try:
d = obj.fn(1) d = obj.fn(1)
@ -347,13 +358,13 @@ class DescriptorTestCase(unittest.TestCase):
return d1 return d1
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache_default_args(self): def test_cache_default_args(self) -> Generator["Deferred[Any]", object, None]:
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached() @descriptors.cached()
def fn(self, arg1, arg2=2, arg3=3): def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str:
return self.mock(arg1, arg2, arg3) return self.mock(arg1, arg2, arg3)
obj = Cls() obj = Cls()
@ -384,13 +395,13 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips") self.assertEqual(r, "chips")
obj.mock.assert_not_called() obj.mock.assert_not_called()
def test_cache_iterable(self): def test_cache_iterable(self) -> None:
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached(iterable=True) @descriptors.cached(iterable=True)
def fn(self, arg1, arg2): def fn(self, arg1: int, arg2: int) -> List[str]:
return self.mock(arg1, arg2) return self.mock(arg1, arg2)
obj = Cls() obj = Cls()
@ -417,12 +428,12 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r.result, ["chips"]) self.assertEqual(r.result, ["chips"])
obj.mock.assert_not_called() obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self): def test_cache_iterable_with_sync_exception(self) -> None:
"""If the wrapped function throws synchronously, things should continue to work""" """If the wrapped function throws synchronously, things should continue to work"""
class Cls: class Cls:
@descriptors.cached(iterable=True) @descriptors.cached(iterable=True)
def fn(self, arg1): def fn(self, arg1: int) -> NoReturn:
raise SynapseError(100, "mai spoon iz too big!!1") raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls() obj = Cls()
@ -438,20 +449,20 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1) d = obj.fn(1)
self.failureResultOf(d, SynapseError) self.failureResultOf(d, SynapseError)
def test_invalidate_cascade(self): def test_invalidate_cascade(self) -> None:
"""Invalidations should cascade up through cache contexts""" """Invalidations should cascade up through cache contexts"""
class Cls: class Cls:
@cached(cache_context=True) @cached(cache_context=True)
async def func1(self, key, cache_context): async def func1(self, key: str, cache_context: _CacheContext) -> int:
return await self.func2(key, on_invalidate=cache_context.invalidate) return await self.func2(key, on_invalidate=cache_context.invalidate)
@cached(cache_context=True) @cached(cache_context=True)
async def func2(self, key, cache_context): async def func2(self, key: str, cache_context: _CacheContext) -> int:
return await self.func3(key, on_invalidate=cache_context.invalidate) return await self.func3(key, on_invalidate=cache_context.invalidate)
@cached(cache_context=True) @cached(cache_context=True)
async def func3(self, key, cache_context): async def func3(self, key: str, cache_context: _CacheContext) -> int:
self.invalidate = cache_context.invalidate self.invalidate = cache_context.invalidate
return 42 return 42
@ -463,13 +474,13 @@ class DescriptorTestCase(unittest.TestCase):
obj.invalidate() obj.invalidate()
top_invalidate.assert_called_once() top_invalidate.assert_called_once()
def test_cancel(self): def test_cancel(self) -> None:
"""Test that cancelling a lookup does not cancel other lookups""" """Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred() complete_lookup: "Deferred[None]" = Deferred()
class Cls: class Cls:
@cached() @cached()
async def fn(self, arg1): async def fn(self, arg1: int) -> str:
await complete_lookup await complete_lookup
return str(arg1) return str(arg1)
@ -488,7 +499,7 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d1, CancelledError) self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, "123") self.assertEqual(d2.result, "123")
def test_cancel_logcontexts(self): def test_cancel_logcontexts(self) -> None:
"""Test that cancellation does not break logcontexts. """Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext. * The `CancelledError` must be raised with the correct logcontext.
@ -501,14 +512,14 @@ class DescriptorTestCase(unittest.TestCase):
inner_context_was_finished = False inner_context_was_finished = False
@cached() @cached()
async def fn(self, arg1): async def fn(self, arg1: int) -> str:
await make_deferred_yieldable(complete_lookup) await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished self.inner_context_was_finished = current_context().finished
return str(arg1) return str(arg1)
obj = Cls() obj = Cls()
async def do_lookup(): async def do_lookup() -> None:
with LoggingContext("c1") as c1: with LoggingContext("c1") as c1:
try: try:
await obj.fn(123) await obj.fn(123)
@ -542,10 +553,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def test_passthrough(self): def test_passthrough(self) -> Generator["Deferred[Any]", object, None]:
class A: class A:
@cached() @cached()
def func(self, key): def func(self, key: str) -> str:
return key return key
a = A() a = A()
@ -554,12 +565,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual((yield a.func("bar")), "bar") self.assertEqual((yield a.func("bar")), "bar")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_hit(self): def test_hit(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0] callcount = [0]
class A: class A:
@cached() @cached()
def func(self, key): def func(self, key: str) -> str:
callcount[0] += 1 callcount[0] += 1
return key return key
@ -572,12 +583,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount[0], 1) self.assertEqual(callcount[0], 1)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invalidate(self): def test_invalidate(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0] callcount = [0]
class A: class A:
@cached() @cached()
def func(self, key): def func(self, key: str) -> str:
callcount[0] += 1 callcount[0] += 1
return key return key
@ -592,21 +603,21 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount[0], 2) self.assertEqual(callcount[0], 2)
def test_invalidate_missing(self): def test_invalidate_missing(self) -> None:
class A: class A:
@cached() @cached()
def func(self, key): def func(self, key: str) -> str:
return key return key
A().func.invalidate(("what",)) A().func.invalidate(("what",))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_max_entries(self): def test_max_entries(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0] callcount = [0]
class A: class A:
@cached(max_entries=10) @cached(max_entries=10)
def func(self, key): def func(self, key: int) -> int:
callcount[0] += 1 callcount[0] += 1
return key return key
@ -626,14 +637,14 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
) )
def test_prefill(self): def test_prefill(self) -> None:
callcount = [0] callcount = [0]
d = defer.succeed(123) d = defer.succeed(123)
class A: class A:
@cached() @cached()
def func(self, key): def func(self, key: str) -> "Deferred[int]":
callcount[0] += 1 callcount[0] += 1
return d return d
@ -645,18 +656,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount[0], 0) self.assertEqual(callcount[0], 0)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invalidate_context(self): def test_invalidate_context(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0] callcount = [0]
callcount2 = [0] callcount2 = [0]
class A: class A:
@cached() @cached()
def func(self, key): def func(self, key: str) -> str:
callcount[0] += 1 callcount[0] += 1
return key return key
@cached(cache_context=True) @cached(cache_context=True)
def func2(self, key, cache_context): def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
callcount2[0] += 1 callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate) return self.func(key, on_invalidate=cache_context.invalidate)
@ -678,18 +689,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount2[0], 2) self.assertEqual(callcount2[0], 2)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_eviction_context(self): def test_eviction_context(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0] callcount = [0]
callcount2 = [0] callcount2 = [0]
class A: class A:
@cached(max_entries=2) @cached(max_entries=2)
def func(self, key): def func(self, key: str) -> str:
callcount[0] += 1 callcount[0] += 1
return key return key
@cached(cache_context=True) @cached(cache_context=True)
def func2(self, key, cache_context): def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
callcount2[0] += 1 callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate) return self.func(key, on_invalidate=cache_context.invalidate)
@ -715,18 +726,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount2[0], 3) self.assertEqual(callcount2[0], 3)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_double_get(self): def test_double_get(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0] callcount = [0]
callcount2 = [0] callcount2 = [0]
class A: class A:
@cached() @cached()
def func(self, key): def func(self, key: str) -> str:
callcount[0] += 1 callcount[0] += 1
return key return key
@cached(cache_context=True) @cached(cache_context=True)
def func2(self, key, cache_context): def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
callcount2[0] += 1 callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate) return self.func(key, on_invalidate=cache_context.invalidate)
@ -763,17 +774,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class CachedListDescriptorTestCase(unittest.TestCase): class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache(self): def test_cache(self) -> Generator["Deferred[Any]", object, None]:
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached() @descriptors.cached()
def fn(self, arg1, arg2): def fn(self, arg1: int, arg2: int) -> None:
pass pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1") @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2): async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]:
context = current_context() context = current_context()
assert isinstance(context, LoggingContext) assert isinstance(context, LoggingContext)
assert context.name == "c1" assert context.name == "c1"
@ -824,19 +835,19 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.mock.assert_called_once_with({40}, 2) obj.mock.assert_called_once_with({40}, 2)
self.assertEqual(r, {10: "fish", 40: "gravy"}) self.assertEqual(r, {10: "fish", 40: "gravy"})
def test_concurrent_lookups(self): def test_concurrent_lookups(self) -> None:
"""All concurrent lookups should get the same result""" """All concurrent lookups should get the same result"""
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached() @descriptors.cached()
def fn(self, arg1): def fn(self, arg1: int) -> None:
pass pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1") @descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1) -> "Deferred[dict]": def list_fn(self, args1: List[int]) -> "Deferred[dict]":
return self.mock(args1) return self.mock(args1)
obj = Cls() obj = Cls()
@ -867,19 +878,19 @@ class CachedListDescriptorTestCase(unittest.TestCase):
self.assertEqual(self.successResultOf(d3), {10: "peas"}) self.assertEqual(self.successResultOf(d3), {10: "peas"})
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invalidate(self): def test_invalidate(self) -> Generator["Deferred[Any]", object, None]:
"""Make sure that invalidation callbacks are called.""" """Make sure that invalidation callbacks are called."""
class Cls: class Cls:
def __init__(self): def __init__(self) -> None:
self.mock = mock.Mock() self.mock = mock.Mock()
@descriptors.cached() @descriptors.cached()
def fn(self, arg1, arg2): def fn(self, arg1: int, arg2: int) -> None:
pass pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1") @descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2): async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]:
# we want this to behave like an asynchronous function # we want this to behave like an asynchronous function
await run_on_reactor() await run_on_reactor()
return self.mock(args1, arg2) return self.mock(args1, arg2)
@ -908,17 +919,17 @@ class CachedListDescriptorTestCase(unittest.TestCase):
invalidate0.assert_called_once() invalidate0.assert_called_once()
invalidate1.assert_called_once() invalidate1.assert_called_once()
def test_cancel(self): def test_cancel(self) -> None:
"""Test that cancelling a lookup does not cancel other lookups""" """Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred() complete_lookup: "Deferred[None]" = Deferred()
class Cls: class Cls:
@cached() @cached()
def fn(self, arg1): def fn(self, arg1: int) -> None:
pass pass
@cachedList(cached_method_name="fn", list_name="args") @cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args): async def list_fn(self, args: List[int]) -> Dict[int, str]:
await complete_lookup await complete_lookup
return {arg: str(arg) for arg in args} return {arg: str(arg) for arg in args}
@ -936,7 +947,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
self.failureResultOf(d1, CancelledError) self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
def test_cancel_logcontexts(self): def test_cancel_logcontexts(self) -> None:
"""Test that cancellation does not break logcontexts. """Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext. * The `CancelledError` must be raised with the correct logcontext.
@ -949,18 +960,18 @@ class CachedListDescriptorTestCase(unittest.TestCase):
inner_context_was_finished = False inner_context_was_finished = False
@cached() @cached()
def fn(self, arg1): def fn(self, arg1: int) -> None:
pass pass
@cachedList(cached_method_name="fn", list_name="args") @cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args): async def list_fn(self, args: List[int]) -> Dict[int, str]:
await make_deferred_yieldable(complete_lookup) await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args} return {arg: str(arg) for arg in args}
obj = Cls() obj = Cls()
async def do_lookup(): async def do_lookup() -> None:
with LoggingContext("c1") as c1: with LoggingContext("c1") as c1:
try: try:
await obj.list_fn([123]) await obj.list_fn([123])
@ -983,7 +994,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
) )
self.assertEqual(current_context(), SENTINEL_CONTEXT) self.assertEqual(current_context(), SENTINEL_CONTEXT)
def test_num_args_mismatch(self): def test_num_args_mismatch(self) -> None:
""" """
Make sure someone does not accidentally use @cachedList on a method with Make sure someone does not accidentally use @cachedList on a method with
a mismatch in the number args to the underlying single cache method. a mismatch in the number args to the underlying single cache method.
@ -991,14 +1002,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
class Cls: class Cls:
@descriptors.cached(tree=True) @descriptors.cached(tree=True)
def fn(self, room_id, event_id): def fn(self, room_id: str, event_id: str) -> None:
pass pass
# This is wrong ❌. `@cachedList` expects to be given the same number # This is wrong ❌. `@cachedList` expects to be given the same number
# of arguments as the underlying cached function, just with one of # of arguments as the underlying cached function, just with one of
# the arguments being an iterable # the arguments being an iterable
@descriptors.cachedList(cached_method_name="fn", list_name="keys") @descriptors.cachedList(cached_method_name="fn", list_name="keys")
def list_fn(self, keys: Iterable[Tuple[str, str]]): def list_fn(self, keys: Iterable[Tuple[str, str]]) -> None:
pass pass
# Corrected syntax ✅ # Corrected syntax ✅