forked from MirrorHub/synapse
Add missing type hints to test.util.caches (#14529)
This commit is contained in:
parent
7f78b383ca
commit
4ae967cf63
7 changed files with 76 additions and 66 deletions
1
changelog.d/14529.misc
Normal file
1
changelog.d/14529.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add missing type hints.
|
11
mypy.ini
11
mypy.ini
|
@ -59,11 +59,6 @@ exclude = (?x)
|
|||
|tests/server_notices/test_resource_limits_server_notices.py
|
||||
|tests/test_state.py
|
||||
|tests/test_terms_auth.py
|
||||
|tests/util/caches/test_cached_call.py
|
||||
|tests/util/caches/test_deferred_cache.py
|
||||
|tests/util/caches/test_descriptors.py
|
||||
|tests/util/caches/test_response_cache.py
|
||||
|tests/util/caches/test_ttlcache.py
|
||||
|tests/util/test_async_helpers.py
|
||||
|tests/util/test_batching_queue.py
|
||||
|tests/util/test_dict_cache.py
|
||||
|
@ -133,6 +128,12 @@ disallow_untyped_defs = True
|
|||
[mypy-tests.federation.transport.test_client]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.util.caches.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.util.caches.test_descriptors]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-tests.utils]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import NoReturn
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -23,14 +24,14 @@ from tests.unittest import TestCase
|
|||
|
||||
|
||||
class CachedCallTestCase(TestCase):
|
||||
def test_get(self):
|
||||
def test_get(self) -> None:
|
||||
"""
|
||||
Happy-path test case: makes a couple of calls and makes sure they behave
|
||||
correctly
|
||||
"""
|
||||
d = Deferred()
|
||||
d: "Deferred[int]" = Deferred()
|
||||
|
||||
async def f():
|
||||
async def f() -> int:
|
||||
return await d
|
||||
|
||||
slow_call = Mock(side_effect=f)
|
||||
|
@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase):
|
|||
# now fire off a couple of calls
|
||||
completed_results = []
|
||||
|
||||
async def r():
|
||||
async def r() -> None:
|
||||
res = await cached_call.get()
|
||||
completed_results.append(res)
|
||||
|
||||
|
@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase):
|
|||
self.assertEqual(r3, 123)
|
||||
slow_call.assert_not_called()
|
||||
|
||||
def test_fast_call(self):
|
||||
def test_fast_call(self) -> None:
|
||||
"""
|
||||
Test the behaviour when the underlying function completes immediately
|
||||
"""
|
||||
|
||||
async def f():
|
||||
async def f() -> int:
|
||||
return 12
|
||||
|
||||
fast_call = Mock(side_effect=f)
|
||||
|
@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase):
|
|||
|
||||
|
||||
class RetryOnExceptionCachedCallTestCase(TestCase):
|
||||
def test_get(self):
|
||||
def test_get(self) -> None:
|
||||
# set up the RetryOnExceptionCachedCall around a function which will fail
|
||||
# (after a while)
|
||||
d = Deferred()
|
||||
d: "Deferred[int]" = Deferred()
|
||||
|
||||
async def f1():
|
||||
async def f1() -> NoReturn:
|
||||
await d
|
||||
raise ValueError("moo")
|
||||
|
||||
|
@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
|
|||
# now fire off a couple of calls
|
||||
completed_results = []
|
||||
|
||||
async def r():
|
||||
async def r() -> None:
|
||||
try:
|
||||
await cached_call.get()
|
||||
except Exception as e1:
|
||||
|
@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase):
|
|||
# to the getter
|
||||
d = Deferred()
|
||||
|
||||
async def f2():
|
||||
async def f2() -> int:
|
||||
return await d
|
||||
|
||||
slow_call.reset_mock()
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -22,20 +23,20 @@ from tests.unittest import TestCase
|
|||
|
||||
|
||||
class DeferredCacheTestCase(TestCase):
|
||||
def test_empty(self):
|
||||
cache = DeferredCache("test")
|
||||
def test_empty(self) -> None:
|
||||
cache: DeferredCache[str, int] = DeferredCache("test")
|
||||
with self.assertRaises(KeyError):
|
||||
cache.get("foo")
|
||||
|
||||
def test_hit(self):
|
||||
cache = DeferredCache("test")
|
||||
def test_hit(self) -> None:
|
||||
cache: DeferredCache[str, int] = DeferredCache("test")
|
||||
cache.prefill("foo", 123)
|
||||
|
||||
self.assertEqual(self.successResultOf(cache.get("foo")), 123)
|
||||
|
||||
def test_hit_deferred(self):
|
||||
cache = DeferredCache("test")
|
||||
origin_d = defer.Deferred()
|
||||
def test_hit_deferred(self) -> None:
|
||||
cache: DeferredCache[str, int] = DeferredCache("test")
|
||||
origin_d: "defer.Deferred[int]" = defer.Deferred()
|
||||
set_d = cache.set("k1", origin_d)
|
||||
|
||||
# get should return an incomplete deferred
|
||||
|
@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase):
|
|||
self.assertFalse(get_d.called)
|
||||
|
||||
# add a callback that will make sure that the set_d gets called before the get_d
|
||||
def check1(r):
|
||||
def check1(r: str) -> str:
|
||||
self.assertTrue(set_d.called)
|
||||
return r
|
||||
|
||||
|
@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase):
|
|||
self.assertEqual(self.successResultOf(set_d), 99)
|
||||
self.assertEqual(self.successResultOf(get_d), 99)
|
||||
|
||||
def test_callbacks(self):
|
||||
def test_callbacks(self) -> None:
|
||||
"""Invalidation callbacks are called at the right time"""
|
||||
cache = DeferredCache("test")
|
||||
cache: DeferredCache[str, int] = DeferredCache("test")
|
||||
callbacks = set()
|
||||
|
||||
# start with an entry, with a callback
|
||||
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
|
||||
|
||||
# now replace that entry with a pending result
|
||||
origin_d = defer.Deferred()
|
||||
origin_d: "defer.Deferred[int]" = defer.Deferred()
|
||||
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
|
||||
|
||||
# ... and also make a get request
|
||||
|
@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase):
|
|||
cache.prefill("k1", 30)
|
||||
self.assertEqual(callbacks, {"set", "get"})
|
||||
|
||||
def test_set_fail(self):
|
||||
cache = DeferredCache("test")
|
||||
def test_set_fail(self) -> None:
|
||||
cache: DeferredCache[str, int] = DeferredCache("test")
|
||||
callbacks = set()
|
||||
|
||||
# start with an entry, with a callback
|
||||
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
|
||||
|
||||
# now replace that entry with a pending result
|
||||
origin_d = defer.Deferred()
|
||||
origin_d: defer.Deferred = defer.Deferred()
|
||||
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
|
||||
|
||||
# ... and also make a get request
|
||||
|
@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase):
|
|||
cache.prefill("k1", 30)
|
||||
self.assertEqual(callbacks, {"prefill", "get2"})
|
||||
|
||||
def test_get_immediate(self):
|
||||
cache = DeferredCache("test")
|
||||
d1 = defer.Deferred()
|
||||
def test_get_immediate(self) -> None:
|
||||
cache: DeferredCache[str, int] = DeferredCache("test")
|
||||
d1: "defer.Deferred[int]" = defer.Deferred()
|
||||
cache.set("key1", d1)
|
||||
|
||||
# get_immediate should return default
|
||||
|
@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase):
|
|||
v = cache.get_immediate("key1", 1)
|
||||
self.assertEqual(v, 2)
|
||||
|
||||
def test_invalidate(self):
|
||||
cache = DeferredCache("test")
|
||||
def test_invalidate(self) -> None:
|
||||
cache: DeferredCache[Tuple[str], int] = DeferredCache("test")
|
||||
cache.prefill(("foo",), 123)
|
||||
cache.invalidate(("foo",))
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
cache.get(("foo",))
|
||||
|
||||
def test_invalidate_all(self):
|
||||
cache = DeferredCache("testcache")
|
||||
def test_invalidate_all(self) -> None:
|
||||
cache: DeferredCache[str, str] = DeferredCache("testcache")
|
||||
|
||||
callback_record = [False, False]
|
||||
|
||||
def record_callback(idx):
|
||||
def record_callback(idx: int) -> None:
|
||||
callback_record[idx] = True
|
||||
|
||||
# add a couple of pending entries
|
||||
d1 = defer.Deferred()
|
||||
d1: "defer.Deferred[str]" = defer.Deferred()
|
||||
cache.set("key1", d1, partial(record_callback, 0))
|
||||
|
||||
d2 = defer.Deferred()
|
||||
d2: "defer.Deferred[str]" = defer.Deferred()
|
||||
cache.set("key2", d2, partial(record_callback, 1))
|
||||
|
||||
# lookup should return pending deferreds
|
||||
|
@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase):
|
|||
with self.assertRaises(KeyError):
|
||||
cache.get("key1", None)
|
||||
|
||||
def test_eviction(self):
|
||||
cache = DeferredCache(
|
||||
def test_eviction(self) -> None:
|
||||
cache: DeferredCache[int, str] = DeferredCache(
|
||||
"test", max_entries=2, apply_cache_factor_from_config=False
|
||||
)
|
||||
|
||||
|
@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase):
|
|||
cache.get(2)
|
||||
cache.get(3)
|
||||
|
||||
def test_eviction_lru(self):
|
||||
cache = DeferredCache(
|
||||
def test_eviction_lru(self) -> None:
|
||||
cache: DeferredCache[int, str] = DeferredCache(
|
||||
"test", max_entries=2, apply_cache_factor_from_config=False
|
||||
)
|
||||
|
||||
|
@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase):
|
|||
cache.get(1)
|
||||
cache.get(3)
|
||||
|
||||
def test_eviction_iterable(self):
|
||||
cache = DeferredCache(
|
||||
def test_eviction_iterable(self) -> None:
|
||||
cache: DeferredCache[int, List[str]] = DeferredCache(
|
||||
"test",
|
||||
max_entries=3,
|
||||
apply_cache_factor_from_config=False,
|
||||
|
|
|
@ -13,11 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Iterable, Set, Tuple
|
||||
from typing import Iterable, Set, Tuple, cast
|
||||
from unittest import mock
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import CancelledError, Deferred
|
||||
from twisted.internet.interfaces import IReactorTime
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.logging.context import (
|
||||
|
@ -37,8 +38,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def run_on_reactor():
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0, d.callback, 0)
|
||||
d: "Deferred[int]" = defer.Deferred()
|
||||
cast(IReactorTime, reactor).callLater(0, d.callback, 0)
|
||||
return make_deferred_yieldable(d)
|
||||
|
||||
|
||||
|
@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase):
|
|||
callbacks: Set[str] = set()
|
||||
|
||||
# set off an asynchronous request
|
||||
obj.result = origin_d = defer.Deferred()
|
||||
origin_d: Deferred = defer.Deferred()
|
||||
obj.result = origin_d
|
||||
|
||||
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
|
||||
self.assertFalse(d1.called)
|
||||
|
@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase):
|
|||
"""Check that logcontexts are set and restored correctly when
|
||||
using the cache."""
|
||||
|
||||
complete_lookup = defer.Deferred()
|
||||
complete_lookup: Deferred = defer.Deferred()
|
||||
|
||||
class Cls:
|
||||
@descriptors.cached()
|
||||
|
@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
|||
|
||||
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
|
||||
async def list_fn(self, args1, arg2):
|
||||
assert current_context().name == "c1"
|
||||
context = current_context()
|
||||
assert isinstance(context, LoggingContext)
|
||||
assert context.name == "c1"
|
||||
# we want this to behave like an asynchronous function
|
||||
await run_on_reactor()
|
||||
assert current_context().name == "c1"
|
||||
context = current_context()
|
||||
assert isinstance(context, LoggingContext)
|
||||
assert context.name == "c1"
|
||||
return self.mock(args1, arg2)
|
||||
|
||||
with LoggingContext("c1") as c1:
|
||||
|
@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
|||
return self.mock(args1)
|
||||
|
||||
obj = Cls()
|
||||
deferred_result = Deferred()
|
||||
deferred_result: "Deferred[dict]" = Deferred()
|
||||
obj.mock.return_value = deferred_result
|
||||
|
||||
# start off several concurrent lookups of the same key
|
||||
|
|
|
@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.reactor, self.clock = get_clock()
|
||||
|
||||
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
|
||||
|
@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
await self.clock.sleep(1)
|
||||
return o
|
||||
|
||||
def test_cache_hit(self):
|
||||
def test_cache_hit(self) -> None:
|
||||
cache = self.with_cache("keeping_cache", ms=9001)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
"cache should still have the result",
|
||||
)
|
||||
|
||||
def test_cache_miss(self):
|
||||
def test_cache_miss(self) -> None:
|
||||
cache = self.with_cache("trashing_cache", ms=0)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
)
|
||||
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
|
||||
|
||||
def test_cache_expire(self):
|
||||
def test_cache_expire(self) -> None:
|
||||
cache = self.with_cache("short_cache", ms=1000)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
self.reactor.pump((2,))
|
||||
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
|
||||
|
||||
def test_cache_wait_hit(self):
|
||||
def test_cache_wait_hit(self) -> None:
|
||||
cache = self.with_cache("neutral_cache")
|
||||
|
||||
expected_result = "howdy"
|
||||
|
@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d))
|
||||
|
||||
def test_cache_wait_expire(self):
|
||||
def test_cache_wait_expire(self) -> None:
|
||||
cache = self.with_cache("medium_cache", ms=3000)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
|
||||
|
||||
@parameterized.expand([(True,), (False,)])
|
||||
def test_cache_context_nocache(self, should_cache: bool):
|
||||
def test_cache_context_nocache(self, should_cache: bool) -> None:
|
||||
"""If the callback clears the should_cache bit, the result should not be cached"""
|
||||
cache = self.with_cache("medium_cache", ms=3000)
|
||||
|
||||
|
@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase):
|
|||
|
||||
call_count = 0
|
||||
|
||||
async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
|
||||
async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await self.clock.sleep(1)
|
||||
|
|
|
@ -20,11 +20,11 @@ from tests import unittest
|
|||
|
||||
|
||||
class CacheTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.mock_timer = Mock(side_effect=lambda: 100.0)
|
||||
self.cache = TTLCache("test_cache", self.mock_timer)
|
||||
self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer)
|
||||
|
||||
def test_get(self):
|
||||
def test_get(self) -> None:
|
||||
"""simple set/get tests"""
|
||||
self.cache.set("one", "1", 10)
|
||||
self.cache.set("two", "2", 20)
|
||||
|
@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase):
|
|||
self.assertEqual(self.cache._metrics.hits, 4)
|
||||
self.assertEqual(self.cache._metrics.misses, 5)
|
||||
|
||||
def test_expiry(self):
|
||||
def test_expiry(self) -> None:
|
||||
self.cache.set("one", "1", 10)
|
||||
self.cache.set("two", "2", 20)
|
||||
self.cache.set("three", "3", 30)
|
||||
|
|
Loading…
Reference in a new issue