0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-30 15:34:02 +01:00

Add missing types to tests.util. (#14597)

Removes files under tests.util from the ignored by list, then
fully types all tests/util/*.py files.
This commit is contained in:
Patrick Cloke 2022-12-02 12:58:56 -05:00 committed by GitHub
parent fac8a38525
commit acea4d7a2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 361 additions and 276 deletions

1
changelog.d/14597.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints.

View file

@ -59,16 +59,6 @@ exclude = (?x)
|tests/server_notices/test_resource_limits_server_notices.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/util/test_async_helpers.py
|tests/util/test_batching_queue.py
|tests/util/test_dict_cache.py
|tests/util/test_expiring_cache.py
|tests/util/test_file_consumer.py
|tests/util/test_linearizer.py
|tests/util/test_logcontext.py
|tests/util/test_lrucache.py
|tests/util/test_rwlock.py
|tests/util/test_wheel_timer.py
)$
[mypy-synapse.federation.transport.client]
@ -137,6 +127,9 @@ disallow_untyped_defs = True
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False
[mypy-tests.util.*]
disallow_untyped_defs = True
[mypy-tests.utils]
disallow_untyped_defs = True

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Generator, List, NoReturn, Optional
from parameterized import parameterized_class
@ -41,8 +42,8 @@ from tests.unittest import TestCase
class ObservableDeferredTest(TestCase):
def test_succeed(self):
origin_d = Deferred()
def test_succeed(self) -> None:
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d)
observer1 = observable.observe()
@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
def check_called_first(res):
def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
results = [None, None]
results: List[Optional[ObservableDeferred[int]]] = [None, None]
def check_val(res, idx):
def check_val(
res: ObservableDeferred[int], idx: int
) -> ObservableDeferred[int]:
results[idx] = res
return res
@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result")
def test_failure(self):
origin_d = Deferred()
def test_failure(self) -> None:
origin_d: Deferred = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
observer1 = observable.observe()
@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
def check_called_first(res):
def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
results = [None, None]
results: List[Optional[ObservableDeferred[str]]] = [None, None]
def check_val(res, idx):
def check_val(res: ObservableDeferred[str], idx: int) -> None:
results[idx] = res
return None
@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase):
raise Exception("gah!")
except Exception as e:
origin_d.errback(e)
assert results[0] is not None
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
assert results[1] is not None
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
def test_cancellation(self):
def test_cancellation(self) -> None:
"""Test that cancelling an observer does not affect other observers."""
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase):
class TimeoutDeferredTest(TestCase):
def setUp(self):
def setUp(self) -> None:
self.clock = Clock()
def test_times_out(self):
def test_times_out(self) -> None:
"""Basic test case that checks that the original deferred is cancelled and that
the timing-out deferred is errbacked
"""
cancelled = [False]
cancelled = False
def canceller(_d):
cancelled[0] = True
def canceller(_d: Deferred) -> None:
nonlocal cancelled
cancelled = True
non_completing_d = Deferred(canceller)
non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
self.assertFalse(cancelled, "deferred was cancelled prematurely")
self.clock.pump((1.0,))
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
self.assertTrue(cancelled, "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_times_out_when_canceller_throws(self):
def test_times_out_when_canceller_throws(self) -> None:
"""Test that we have successfully worked around
https://twistedmatrix.com/trac/ticket/9534"""
def canceller(_d):
def canceller(_d: Deferred) -> None:
raise Exception("can't cancel this deferred")
non_completing_d = Deferred(canceller)
non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase):
self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_logcontext_is_preserved_on_cancellation(self):
blocking_was_cancelled = [False]
def test_logcontext_is_preserved_on_cancellation(self) -> None:
blocking_was_cancelled = False
@defer.inlineCallbacks
def blocking():
non_completing_d = Deferred()
def blocking() -> Generator["Deferred[object]", object, None]:
nonlocal blocking_was_cancelled
non_completing_d: Deferred = Deferred()
with PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
blocking_was_cancelled[0] = True
blocking_was_cancelled = True
raise
with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
def errback(res: Failure, deferred_name: str) -> Failure:
self.assertIs(
current_context(),
context_one,
@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase):
self.clock.pump((1.0,))
self.assertTrue(
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
blocking_was_cancelled, "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
@ -220,13 +228,13 @@ class _TestException(Exception):
class ConcurrentlyExecuteTest(TestCase):
def test_limits_runners(self):
def test_limits_runners(self) -> None:
"""If we have more tasks than runners, we should get the limit of runners"""
started = 0
waiters = []
processed = []
async def callback(v):
async def callback(v: int) -> None:
# when we first enter, bump the start count
nonlocal started
started += 1
@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase):
processed.append(v)
# wait for the goahead before returning
d2 = Deferred()
d2: "Deferred[int]" = Deferred()
waiters.append(d2)
await d2
@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase):
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2)
def test_preserves_stacktraces(self):
def test_preserves_stacktraces(self) -> None:
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
d1 = Deferred()
d1: "Deferred[int]" = Deferred()
async def callback(v):
async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
raise _TestException("bah")
async def caller():
async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase):
d1.callback(0)
self.successResultOf(d2)
def test_preserves_stacktraces_on_preformed_failure(self):
def test_preserves_stacktraces_on_preformed_failure(self) -> None:
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
d1 = Deferred()
d1: "Deferred[int]" = Deferred()
f = Failure(_TestException("bah"))
async def callback(v):
async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
await defer.fail(f)
async def caller():
async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase):
else:
raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
def test_succeed(self):
def test_succeed(self) -> None:
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertEqual("success", self.successResultOf(wrapper_deferred))
def test_failure(self):
def test_failure(self) -> None:
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase):
class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""
def test_cancellation(self):
def test_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
@ -384,7 +392,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""
def test_deferred_cancellation(self):
def test_deferred_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_coroutine_cancellation(self):
def test_coroutine_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()
async def task():
async def task() -> NoReturn:
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase):
# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_suppresses_second_cancellation(self):
def test_suppresses_second_cancellation(self) -> None:
"""Test that a second cancellation is suppressed.
Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_propagates_cancelled_error(self):
def test_propagates_cancelled_error(self) -> None:
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
def test_preserves_logcontext(self):
def test_preserves_logcontext(self) -> None:
"""Test that logging contexts are preserved."""
blocking_d: "Deferred[None]" = Deferred()
async def inner():
async def inner() -> None:
await make_deferred_yieldable(blocking_d)
async def outer():
async def outer() -> None:
with LoggingContext("c") as c:
try:
await delay_cancellation(inner())
@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase):
class AwakenableSleeperTests(TestCase):
"Tests AwakenableSleeper"
def test_sleep(self):
def test_sleep(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d.called)
def test_explicit_wake(self):
def test_explicit_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
def test_multiple_sleepers_timeout(self):
def test_multiple_sleepers_timeout(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d2.called)
def test_multiple_sleepers_wake(self):
def test_multiple_sleepers_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)

View file

@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable
@ -26,7 +30,7 @@ from tests.unittest import TestCase
class BatchingQueueTestCase(TestCase):
def setUp(self):
def setUp(self) -> None:
self.clock, hs_clock = get_clock()
# We ensure that we remove any existing metrics for "test_queue".
@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase):
except KeyError:
pass
self._pending_calls = []
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
self.queue: BatchingQueue[str, str] = BatchingQueue(
"test_queue", hs_clock, self._process_queue
)
async def _process_queue(self, values):
d = defer.Deferred()
async def _process_queue(self, values: List[str]) -> str:
d: "defer.Deferred[str]" = defer.Deferred()
self._pending_calls.append((values, d))
return await make_deferred_yieldable(d)
def _get_sample_with_name(self, metric, name) -> int:
def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
"""For a prometheus metric get the value of the sample that has a
matching "name" label.
"""
for sample in metric.collect()[0].samples:
for sample in next(iter(metric.collect())).samples:
if sample.labels.get("name") == name:
return sample.value
self.fail("Found no matching sample")
def _assert_metrics(self, queued, keys, in_flight):
def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
"""Assert that the metrics are correct"""
sample = self._get_sample_with_name(number_queued, self.queue._name)
@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase):
"number_in_flight",
)
def test_simple(self):
def test_simple(self) -> None:
"""Tests the basic case of calling `add_to_queue` once and having
`_process_queue` return.
"""
@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase):
self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_batching(self):
def test_batching(self) -> None:
"""Test that multiple calls at the same time get batched up into one
call to `_process_queue`.
"""
@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d2), "bar")
self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_queuing(self):
def test_queuing(self) -> None:
"""Test that we queue up requests while a `_process_queue` is being
called.
"""
@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d3), "bar2")
self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_different_keys(self):
def test_different_keys(self) -> None:
"""Test that calls to different keys get processed in parallel."""
self.assertFalse(self._pending_calls)

View file

@ -1,5 +1,20 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Generator, Optional
from os import PathLike
from typing import Generator, Optional, Union
from unittest.mock import patch
from synapse.util.check_dependencies import (
@ -12,17 +27,17 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
def __init__(self, version: object):
def __init__(self, version: str):
self._version = version
@property
def version(self):
def version(self) -> str:
return self._version
def locate_file(self, path):
def locate_file(self, path: Union[str, PathLike]) -> PathLike:
raise NotImplementedError()
def read_text(self, filename):
def read_text(self, filename: str) -> None:
raise NotImplementedError()
@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4")
distribution_with_no_version = DummyDistribution(None)
distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type]
# could probably use stdlib TestCase --- no need for twisted here
@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase):
If `distribution = None`, we pretend that the package is not installed.
"""
def mock_distribution(name: str):
def mock_distribution(name: str) -> DummyDistribution:
if distribution is None:
raise metadata.PackageNotFoundError
else:

View file

@ -19,10 +19,12 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase):
def setUp(self):
self.cache = DictionaryCache("foobar", max_entries=10)
def setUp(self) -> None:
self.cache: DictionaryCache[str, str, str] = DictionaryCache(
"foobar", max_entries=10
)
def test_simple_cache_hit_full(self):
def test_simple_cache_hit_full(self) -> None:
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
def test_simple_cache_hit_partial(self):
def test_simple_cache_hit_partial(self) -> None:
key = "test_simple_cache_hit_partial"
seq = self.cache.sequence
@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
def test_simple_cache_miss_partial(self):
def test_simple_cache_miss_partial(self) -> None:
key = "test_simple_cache_miss_partial"
seq = self.cache.sequence
@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
def test_simple_cache_hit_miss_partial(self):
def test_simple_cache_hit_miss_partial(self) -> None:
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
def test_multi_insert(self):
def test_multi_insert(self) -> None:
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase):
)
self.assertEqual(c.full, False)
def test_invalidation(self):
def test_invalidation(self) -> None:
"""Test that the partial dict and full dicts get invalidated
separately.
"""
@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase):
# entry for "a" warm.
for i in range(20):
self.cache.get(key, ["a"])
self.cache.update(seq, f"key{i}", {1: 2})
self.cache.update(seq, f"key{i}", {"1": "2"})
# We should have evicted the full dict...
r = self.cache.get(key)

View file

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, cast
from synapse.util import Clock
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
@ -21,17 +23,21 @@ from .. import unittest
class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
def test_get_set(self) -> None:
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
cache: ExpiringCache[str, str] = ExpiringCache(
"test", cast(Clock, clock), max_len=1
)
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
def test_eviction(self):
def test_eviction(self) -> None:
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=2)
cache: ExpiringCache[str, str] = ExpiringCache(
"test", cast(Clock, clock), max_len=2
)
cache["key"] = "value"
cache["key2"] = "value2"
@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key2"), "value2")
self.assertEqual(cache.get("key3"), "value3")
def test_iterable_eviction(self):
def test_iterable_eviction(self) -> None:
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=5, iterable=True)
cache: ExpiringCache[str, List[int]] = ExpiringCache(
"test", cast(Clock, clock), max_len=5, iterable=True
)
cache["key"] = [1]
cache["key2"] = [2, 3]
@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key3"), [4, 5])
self.assertEqual(cache.get("key4"), [6, 7])
def test_time_eviction(self):
def test_time_eviction(self) -> None:
clock = MockClock()
cache = ExpiringCache("test", clock, expiry_ms=1000)
cache: ExpiringCache[str, int] = ExpiringCache(
"test", cast(Clock, clock), expiry_ms=1000
)
cache["key"] = 1
clock.advance_time(0.5)

View file

@ -12,22 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from io import StringIO
from io import BytesIO
from typing import BinaryIO, Generator, Optional, cast
from unittest.mock import NonCallableMock
from twisted.internet import defer, reactor
from zope.interface import implementer
from twisted.internet import defer, reactor as _reactor
from twisted.internet.interfaces import IPullProducer
from synapse.types import ISynapseReactor
from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest
reactor = cast(ISynapseReactor, _reactor)
class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self):
string_file = StringIO()
def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
string_file = BytesIO()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase):
yield producer.register_with_consumer(consumer)
yield producer.write_and_wait("Foo")
yield producer.write_and_wait(b"Foo")
self.assertEqual(string_file.getvalue(), "Foo")
self.assertEqual(string_file.getvalue(), b"Foo")
yield producer.write_and_wait("Bar")
yield producer.write_and_wait(b"Bar")
self.assertEqual(string_file.getvalue(), "FooBar")
self.assertEqual(string_file.getvalue(), b"FooBar")
finally:
consumer.unregisterProducer()
yield consumer.wait()
yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
def test_push_consumer(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
string_file = BlockingBytesWrite()
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
consumer.registerProducer(producer, True)
consumer.write("Foo")
yield string_file.wait_for_n_writes(1)
consumer.write(b"Foo")
yield string_file.wait_for_n_writes(1) # type: ignore[misc]
self.assertEqual(string_file.buffer, "Foo")
self.assertEqual(string_file.buffer, b"Foo")
consumer.write("Bar")
yield string_file.wait_for_n_writes(2)
consumer.write(b"Bar")
yield string_file.wait_for_n_writes(2) # type: ignore[misc]
self.assertEqual(string_file.buffer, "FooBar")
self.assertEqual(string_file.buffer, b"FooBar")
finally:
consumer.unregisterProducer()
yield consumer.wait()
yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
def test_push_producer_feedback(
self,
) -> Generator["defer.Deferred[object]", object, None]:
string_file = BlockingBytesWrite()
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
resume_deferred = defer.Deferred()
resume_deferred: defer.Deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
None
)
@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase):
number_writes = 0
with string_file.write_lock:
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
consumer.write("Foo")
consumer.write(b"Foo")
number_writes += 1
producer.pauseProducing.assert_called_once()
yield string_file.wait_for_n_writes(number_writes)
yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc]
yield resume_deferred
producer.resumeProducing.assert_called_once()
finally:
consumer.unregisterProducer()
yield consumer.wait()
yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
@implementer(IPullProducer)
class DummyPullProducer:
def __init__(self):
self.consumer = None
self.deferred = defer.Deferred()
def __init__(self) -> None:
self.consumer: Optional[BackgroundFileConsumer] = None
self.deferred: "defer.Deferred[object]" = defer.Deferred()
def resumeProducing(self):
def resumeProducing(self) -> None:
d = self.deferred
self.deferred = defer.Deferred()
d.callback(None)
def write_and_wait(self, bytes):
def stopProducing(self) -> None:
raise RuntimeError("Unexpected call")
def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
assert self.consumer is not None
d = self.deferred
self.consumer.write(bytes)
self.consumer.write(write_bytes)
return d
def register_with_consumer(self, consumer):
def register_with_consumer(
self, consumer: BackgroundFileConsumer
) -> "defer.Deferred[object]":
d = self.deferred
self.consumer = consumer
self.consumer.registerProducer(self, False)
return d
class BlockingStringWrite:
def __init__(self):
self.buffer = ""
class BlockingBytesWrite:
def __init__(self) -> None:
self.buffer = b""
self.closed = False
self.write_lock = threading.Lock()
self._notify_write_deferred = None
self._notify_write_deferred: Optional[defer.Deferred] = None
self._number_of_writes = 0
def write(self, bytes):
def write(self, write_bytes: bytes) -> None:
with self.write_lock:
self.buffer += bytes
self.buffer += write_bytes
self._number_of_writes += 1
reactor.callFromThread(self._notify_write)
def close(self):
def close(self) -> None:
self.closed = True
def _notify_write(self):
def _notify_write(self) -> None:
"Called by write to indicate a write happened"
with self.write_lock:
if not self._notify_write_deferred:
@ -161,7 +176,9 @@ class BlockingStringWrite:
d.callback(None)
@defer.inlineCallbacks
def wait_for_n_writes(self, n):
def wait_for_n_writes(
self, n: int
) -> Generator["defer.Deferred[object]", object, None]:
"Wait for n writes to have happened"
while True:
with self.write_lock:

View file

@ -19,7 +19,7 @@ from tests.unittest import TestCase
class ChunkSeqTests(TestCase):
def test_short_seq(self):
def test_short_seq(self) -> None:
parts = chunk_seq("123", 8)
self.assertEqual(
@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
["123"],
)
def test_long_seq(self):
def test_long_seq(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 8)
self.assertEqual(
@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
["abcdefgh", "ijklmnop"],
)
def test_uneven_parts(self):
def test_uneven_parts(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 5)
self.assertEqual(
@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
["abcde", "fghij", "klmno", "p"],
)
def test_empty_input(self):
def test_empty_input(self) -> None:
parts: Iterable[Sequence] = chunk_seq([], 5)
self.assertEqual(
@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):
class SortTopologically(TestCase):
def test_empty(self):
def test_empty(self) -> None:
"Test that an empty graph works correctly"
graph: Dict[int, List[int]] = {}
self.assertEqual(list(sorted_topologically([], graph)), [])
def test_handle_empty_graph(self):
def test_handle_empty_graph(self) -> None:
"Test that a graph where a node doesn't have an entry is treated as empty"
graph: Dict[int, List[int]] = {}
@ -67,7 +67,7 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
def test_disconnected(self):
def test_disconnected(self) -> None:
"Test that a graph with no edges work"
graph: Dict[int, List[int]] = {1: [], 2: []}
@ -75,20 +75,20 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
def test_linear(self):
def test_linear(self) -> None:
"Test that a simple `4 -> 3 -> 2 -> 1` graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_subset(self):
def test_subset(self) -> None:
"Test that only sorting a subset of the graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
def test_fork(self):
def test_fork(self) -> None:
"Test that a forked graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
@ -96,13 +96,13 @@ class SortTopologically(TestCase):
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_duplicates(self):
def test_duplicates(self) -> None:
"Test that a graph with duplicate edges work"
graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_multiple_paths(self):
def test_multiple_paths(self) -> None:
"Test that a graph with multiple paths between two nodes work"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}

View file

@ -1,5 +1,21 @@
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Generator, cast
import twisted.python.failure
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as _reactor
from synapse.logging.context import (
SENTINEL_CONTEXT,
@ -10,25 +26,30 @@ from synapse.logging.context import (
nested_logging_context,
run_in_background,
)
from synapse.types import ISynapseReactor
from synapse.util import Clock
from .. import unittest
reactor = cast(ISynapseReactor, _reactor)
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
self.assertEqual(current_context().name, value)
def _check_test_key(self, value: str) -> None:
context = current_context()
assert isinstance(context, LoggingContext)
self.assertEqual(context.name, value)
def test_with_context(self):
def test_with_context(self) -> None:
with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
def test_sleep(self):
def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
clock = Clock(reactor)
@defer.inlineCallbacks
def competing_callback():
def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase):
yield clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function):
def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
sentinel_context = current_context()
callback_completed = [False]
callback_completed = False
with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
def cb(res):
callback_completed[0] = True
def cb(res: object) -> object:
nonlocal callback_completed
callback_completed = True
return res
d2.addCallback(cb)
@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase):
# the logcontext is left in a sane state.
d2 = defer.Deferred()
def check_logcontext():
if not callback_completed[0]:
def check_logcontext() -> None:
if not callback_completed:
reactor.callLater(0.01, check_logcontext)
return
@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase):
# test is done once d2 finishes
return d2
def test_run_in_background_with_blocking_fn(self):
def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks
def blocking_function():
def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
def test_run_in_background_with_non_blocking_fn(self):
def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks
def nonblocking_function():
def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
with PreserveLoggingContext():
yield defer.succeed(None)
return self._test_run_in_background(nonblocking_function)
def test_run_in_background_with_chained_deferred(self):
def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
# a function which returns a deferred which looks like it has been
# called, but is actually paused
def testfunc():
def testfunc() -> defer.Deferred:
return make_deferred_yieldable(_chained_deferred_function())
return self._test_run_in_background(testfunc)
def test_run_in_background_with_coroutine(self):
async def testfunc():
def test_run_in_background_with_coroutine(self) -> defer.Deferred:
async def testfunc() -> None:
self._check_test_key("one")
d = Clock(reactor).sleep(0)
self.assertIs(current_context(), SENTINEL_CONTEXT)
@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase):
return self._test_run_in_background(testfunc)
def test_run_in_background_with_nonblocking_coroutine(self):
async def testfunc():
def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
async def testfunc() -> None:
self._check_test_key("one")
return self._test_run_in_background(testfunc)
@defer.inlineCallbacks
def test_make_deferred_yieldable(self):
def test_make_deferred_yieldable(
self,
) -> Generator["defer.Deferred[object]", object, None]:
# a function which returns an incomplete deferred, but doesn't follow
# the synapse rules.
def blocking_function():
d = defer.Deferred()
def blocking_function() -> defer.Deferred:
d: defer.Deferred = defer.Deferred()
reactor.callLater(0, d.callback, None)
return d
@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self):
def test_make_deferred_yieldable_with_chained_deferreds(
self,
) -> Generator["defer.Deferred[object]", object, None]:
sentinel_context = current_context()
with LoggingContext("one"):
@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase):
# now it should be restored
self._check_test_key("one")
def test_nested_logging_context(self):
def test_nested_logging_context(self) -> None:
with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar")
@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
# its callback list, so won't yet call any other new callbacks.
def _chained_deferred_function():
def _chained_deferred_function() -> defer.Deferred:
d = defer.succeed(None)
def cb(res):
d2 = defer.Deferred()
def cb(res: object) -> defer.Deferred:
d2: defer.Deferred = defer.Deferred()
reactor.callLater(0, d2.callback, res)
return d2

View file

@ -23,7 +23,7 @@ class TestException(Exception):
class LogFormatterTestCase(unittest.TestCase):
def test_formatter(self):
def test_formatter(self) -> None:
formatter = LogFormatter()
try:

View file

@ -13,10 +13,11 @@
# limitations under the License.
from typing import List
from typing import List, Tuple
from unittest.mock import Mock, patch
from synapse.metrics.jemalloc import JemallocStats
from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache
@ -25,14 +26,14 @@ from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
cache = LruCache(1)
def test_get_set(self) -> None:
cache: LruCache[str, str] = LruCache(1)
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
def test_eviction(self):
cache = LruCache(2)
def test_eviction(self) -> None:
cache: LruCache[int, int] = LruCache(2)
cache[1] = 1
cache[2] = 2
@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(2), 2)
self.assertEqual(cache.get(3), 3)
def test_setdefault(self):
cache = LruCache(1)
def test_setdefault(self) -> None:
cache: LruCache[str, int] = LruCache(1)
self.assertEqual(cache.setdefault("key", 1), 1)
self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.setdefault("key", 2), 1)
@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache["key"] = 2 # Make sure overriding works.
self.assertEqual(cache.get("key"), 2)
def test_pop(self):
cache = LruCache(1)
def test_pop(self) -> None:
cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1
self.assertEqual(cache.pop("key"), 1)
self.assertEqual(cache.pop("key"), None)
def test_del_multi(self):
cache = LruCache(4, cache_type=TreeCache)
def test_del_multi(self) -> None:
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("animal", "cat")), "mew")
self.assertEqual(cache.get(("vehicles", "car")), "vroom")
cache.del_multi(("animal",))
cache.del_multi(("animal",)) # type: ignore[arg-type]
self.assertEqual(len(cache), 2)
self.assertEqual(cache.get(("animal", "cat")), None)
self.assertEqual(cache.get(("animal", "dog")), None)
@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes".
def test_clear(self):
cache = LruCache(1)
def test_clear(self) -> None:
cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1
cache.clear()
self.assertEqual(len(cache), 0)
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
def test_special_size(self):
cache = LruCache(10, "mycache")
def test_special_size(self) -> None:
cache: LruCache = LruCache(10, "mycache")
self.assertEqual(cache.max_size, 100)
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
def test_get(self) -> None:
m = Mock()
cache = LruCache(1)
cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
def test_multi_get(self):
def test_multi_get(self) -> None:
m = Mock()
cache = LruCache(1)
cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
def test_set(self):
def test_set(self) -> None:
m = Mock()
cache = LruCache(1)
cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
def test_pop(self):
def test_pop(self) -> None:
m = Mock()
cache = LruCache(1)
cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.pop("key")
self.assertEqual(m.call_count, 1)
def test_del_multi(self):
def test_del_multi(self) -> None:
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
cache = LruCache(4, cache_type=TreeCache)
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
cache.del_multi(("a",))
cache.del_multi(("a",)) # type: ignore[arg-type]
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
def test_clear(self):
def test_clear(self) -> None:
m1 = Mock()
m2 = Mock()
cache = LruCache(5)
cache: LruCache[str, str] = LruCache(5)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
def test_eviction(self):
def test_eviction(self) -> None:
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache = LruCache(2)
cache: LruCache[str, str] = LruCache(2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
def test_evict(self) -> None:
cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
cache["key1"] = []
self.assertEqual(len(cache), 0)
assert isinstance(cache.cache, dict)
cache.cache["key1"].drop_from_cache()
self.assertIsNone(
cache.pop("key1"), "Cache entry should have been evicted but wasn't"
@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
class TimeEvictionTestCase(unittest.HomeserverTestCase):
"""Test that time based eviction works correctly."""
def default_config(self):
def default_config(self) -> JsonDict:
config = super().default_config()
config.setdefault("caches", {})["expiry_time"] = "30m"
return config
def test_evict(self):
def test_evict(self) -> None:
setup_expire_lru_cache_entries(self.hs)
cache = LruCache(5, clock=self.hs.get_clock())
cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
# Check that we evict entries we haven't accessed for 30 minutes.
cache["key1"] = 1
@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
}
)
@patch("synapse.util.caches.lrucache.get_jemalloc_stats")
def test_evict_memory(self, jemalloc_interface) -> None:
def test_evict_memory(self, jemalloc_interface: Mock) -> None:
mock_jemalloc_class = Mock(spec=JemallocStats)
jemalloc_interface.return_value = mock_jemalloc_class
@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
mock_jemalloc_class.get_stat.return_value = 924288000
setup_expire_lru_cache_entries(self.hs)
cache = LruCache(4, clock=self.hs.get_clock())
cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
cache["key1"] = 1
cache["key2"] = 2

View file

@ -21,14 +21,14 @@ from tests.unittest import TestCase
class MacaroonGeneratorTestCase(TestCase):
def setUp(self):
def setUp(self) -> None:
self.reactor, hs_clock = get_clock()
self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
self.other_macaroon_generator = MacaroonGenerator(
hs_clock, "tesths", b"anothersecretkey"
)
def test_guest_access_token(self):
def test_guest_access_token(self) -> None:
"""Test the generation and verification of guest access tokens"""
token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
user_id = self.macaroon_generator.verify_guest_token(token)
@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase):
with self.assertRaises(MacaroonVerificationFailedException):
self.macaroon_generator.verify_guest_token(token)
def test_delete_pusher_token(self):
def test_delete_pusher_token(self) -> None:
"""Test the generation and verification of delete_pusher tokens"""
token = self.macaroon_generator.generate_delete_pusher_token(
"@user:tesths", "m.mail", "john@example.com"
@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase):
)
self.assertEqual(user_id, "@user:tesths")
def test_oidc_session_token(self):
def test_oidc_session_token(self) -> None:
"""Test the generation and verification of OIDC session cookies"""
state = "arandomstate"
session_data = OidcSessionData(

View file

@ -13,16 +13,19 @@
# limitations under the License.
from typing import Optional
from twisted.internet.defer import Deferred
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import get_clock
from tests.server import ThreadedMemoryReactorClock, get_clock
from tests.unittest import TestCase
from tests.utils import default_config
class FederationRateLimiterTestCase(TestCase):
def test_ratelimit(self):
def test_ratelimit(self) -> None:
"""A simple test with the default values"""
reactor, clock = get_clock()
rc_config = build_rc_config()
@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase):
# shouldn't block
self.successResultOf(d1)
def test_concurrent_limit(self):
def test_concurrent_limit(self) -> None:
"""Test what happens when we hit the concurrent limit"""
reactor, clock = get_clock()
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase):
cm2.__exit__(None, None, None)
self.successResultOf(d3)
def test_sleep_limit(self):
def test_sleep_limit(self) -> None:
"""Test what happens when we hit the sleep limit"""
reactor, clock = get_clock()
rc_config = build_rc_config(
@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase):
self.assertAlmostEqual(sleep_time, 500, places=3)
def _await_resolution(reactor, d):
def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
"""advance the clock until the deferred completes.
Returns the number of milliseconds it took to complete.
@ -90,7 +93,7 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000
def build_rc_config(settings: Optional[dict] = None):
def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings:
config_dict = default_config("test")
config_dict.update(settings or {})
config = HomeServerConfig()

View file

@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase
class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self):
def test_new_destination(self) -> None:
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastores().main
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
def test_limiter(self):
def test_limiter(self) -> None:
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastores().main

View file

@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
acquired_d: "Deferred[None]" = Deferred()
unblock_d: "Deferred[None]" = Deferred()
async def reader_or_writer():
async def reader_or_writer() -> str:
async with read_or_write(key):
acquired_d.callback(None)
await unblock_d
@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
)
def test_rwlock(self):
def test_rwlock(self) -> None:
rwlock = ReadWriteLock()
key = "key"
@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
_, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
self.assertTrue(acquired_d.called)
def test_lock_handoff_to_nonblocking_writer(self):
def test_lock_handoff_to_nonblocking_writer(self) -> None:
"""Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock()
key = "key"
@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
self.assertTrue(d3.called)
def test_cancellation_while_holding_read_lock(self):
def test_cancellation_while_holding_read_lock(self) -> None:
"""Test cancellation while holding a read lock.
A waiting writer should be given the lock when the reader holding the lock is
@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write completed", self.successResultOf(writer_d))
def test_cancellation_while_holding_write_lock(self):
def test_cancellation_while_holding_write_lock(self) -> None:
"""Test cancellation while holding a write lock.
A waiting reader should be given the lock when the writer holding the lock is
@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("read completed", self.successResultOf(reader_d))
def test_cancellation_while_waiting_for_read_lock(self):
def test_cancellation_while_waiting_for_read_lock(self) -> None:
"""Test cancellation while waiting for a read lock.
Tests that cancelling a waiting reader:
@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
def test_cancellation_while_waiting_for_write_lock(self):
def test_cancellation_while_waiting_for_write_lock(self) -> None:
"""Test cancellation while waiting for a write lock.
Tests that cancelling a waiting writer:

View file

@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
Tests for StreamChangeCache.
"""
def test_prefilled_cache(self):
def test_prefilled_cache(self) -> None:
"""
Providing a prefilled cache to StreamChangeCache will result in a cache
with the prefilled-cache entered in.
@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
def test_has_entity_changed(self):
def test_has_entity_changed(self) -> None:
"""
StreamChangeCache.entity_has_changed will mark entities as changed, and
has_entity_changed will observe the changed entities.
@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
def test_entity_has_changed_pops_off_start(self):
def test_entity_has_changed_pops_off_start(self) -> None:
"""
StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size.
@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
)
self.assertIsNone(cache.get_all_entities_changed(1))
def test_get_all_entities_changed(self):
def test_get_all_entities_changed(self) -> None:
"""
StreamChangeCache.get_all_entities_changed will return all changed
entities since the given position. If the position is before the start
@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
r = cache.get_all_entities_changed(3)
self.assertTrue(r == ok1 or r == ok2)
def test_has_any_entity_changed(self):
def test_has_any_entity_changed(self) -> None:
"""
StreamChangeCache.has_any_entity_changed will return True if any
entities have been changed since the provided stream position, and
@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3))
def test_get_entities_changed(self):
def test_get_entities_changed(self) -> None:
"""
StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the
@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net"},
)
def test_max_pos(self):
def test_max_pos(self) -> None:
"""
StreamChangeCache.get_max_pos_of_last_change will return the most
recent point where the entity could have changed. If the entity is not

View file

@ -19,7 +19,7 @@ from .. import unittest
class StringUtilsTestCase(unittest.TestCase):
def test_client_secret_regex(self):
def test_client_secret_regex(self) -> None:
"""Ensure that client_secret does not contain illegal characters"""
good = [
"abcde12345",
@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
with self.assertRaises(SynapseError):
assert_valid_client_secret(client_secret)
def test_base62_encode(self):
def test_base62_encode(self) -> None:
self.assertEqual("0", base62_encode(0))
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))

View file

@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase
class CanonicaliseEmailTests(HomeserverTestCase):
def test_no_at(self):
def test_no_at(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("address-without-at.bar")
def test_two_at(self):
def test_two_at(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("foo@foo@test.bar")
def test_bad_format(self):
def test_bad_format(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("user@bad.example.net@good.example.com")
def test_valid_format(self):
def test_valid_format(self) -> None:
self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
def test_domain_to_lower(self):
def test_domain_to_lower(self) -> None:
self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
def test_domain_with_umlaut(self):
def test_domain_with_umlaut(self) -> None:
self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
def test_address_casefold(self):
def test_address_casefold(self) -> None:
self.assertEqual(
canonicalise_email("Strauß@Example.com"), "strauss@example.com"
)
def test_address_trim(self):
def test_address_trim(self) -> None:
self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")

View file

@ -19,7 +19,7 @@ from .. import unittest
class TreeCacheTestCase(unittest.TestCase):
def test_get_set_onelevel(self):
def test_get_set_onelevel(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 2)
def test_pop_onelevel(self):
def test_pop_onelevel(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 1)
def test_get_set_twolevel(self):
def test_get_set_twolevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b", "a")), "BA")
self.assertEqual(len(cache), 3)
def test_pop_twolevel(self):
def test_pop_twolevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.pop(("b", "a")), None)
self.assertEqual(len(cache), 1)
def test_pop_mixedlevel(self):
def test_pop_mixedlevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
def test_clear(self):
def test_clear(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
cache.clear()
self.assertEqual(len(cache), 0)
def test_contains(self):
def test_contains(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
self.assertTrue(("a",) in cache)

View file

@ -18,8 +18,8 @@ from .. import unittest
class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self):
wheel = WheelTimer(bucket_size=5)
def test_single_insert_fetch(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 150)
@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self):
wheel = WheelTimer(bucket_size=5)
def test_multi_insert(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(200), [obj3])
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self):
wheel = WheelTimer(bucket_size=5)
def test_insert_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj])
def test_insert_past_multi(self):
wheel = WheelTimer(bucket_size=5)
def test_insert_past_multi(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()