0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 04:52:26 +01:00

Add missing type hints for tests.events. (#14904)

This commit is contained in:
Patrick Cloke 2023-01-25 15:14:03 -05:00 committed by GitHub
parent 8bc5d1406c
commit 3c3ba31507
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 91 additions and 64 deletions

1
changelog.d/14904.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints.

View file

@ -35,8 +35,6 @@ exclude = (?x)
|tests/api/test_auth.py
|tests/app/test_openid_listener.py
|tests/appservice/test_scheduler.py
|tests/events/test_presence_router.py
|tests/events/test_utils.py
|tests/federation/test_federation_catch_up.py
|tests/federation/test_federation_sender.py
|tests/handlers/test_typing.py
@ -86,6 +84,9 @@ disallow_untyped_defs = True
[mypy-tests.crypto.*]
disallow_untyped_defs = True
[mypy-tests.events.*]
disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True

View file

@ -605,10 +605,11 @@ class EventClientSerializer:
_PowerLevel = Union[str, int]
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
def copy_and_fixup_power_levels_contents(
old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
old_power_levels: PowerLevelsContent,
) -> Dict[str, Union[int, Dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing frozendicts along the way.

View file

@ -16,6 +16,8 @@ from unittest.mock import Mock
import attr
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes
from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
from synapse.federation.units import Transaction
@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, presence, room
from synapse.server import HomeServer
from synapse.types import JsonDict, StreamToken, create_requester
from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
from tests.unittest import FederatingHomeserverTestCase, override_config
@attr.s
@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule:
}
return users_to_state
async def get_interested_users(
self, user_id: str
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
if user_id in self._config.users_who_should_receive_all_presence:
return PresenceRouter.ALL_USERS
@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule:
# Initialise a typed config object
config = PresenceRouterTestConfig()
config.users_who_should_receive_all_presence = config_dict.get(
users_who_should_receive_all_presence = config_dict.get(
"users_who_should_receive_all_presence"
)
assert isinstance(users_who_should_receive_all_presence, list)
config.users_who_should_receive_all_presence = (
users_who_should_receive_all_presence
)
return config
@ -96,9 +103,7 @@ class PresenceRouterTestModule:
}
return users_to_state
async def get_interested_users(
self, user_id: str
) -> Union[Set[str], PresenceRouter.ALL_USERS]:
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
if user_id in self._config.users_who_should_receive_all_presence:
return PresenceRouter.ALL_USERS
@ -118,9 +123,14 @@ class PresenceRouterTestModule:
# Initialise a typed config object
config = PresenceRouterTestConfig()
config.users_who_should_receive_all_presence = config_dict.get(
users_who_should_receive_all_presence = config_dict.get(
"users_who_should_receive_all_presence"
)
assert isinstance(users_who_should_receive_all_presence, list)
config.users_who_should_receive_all_presence = (
users_who_should_receive_all_presence
)
return config
@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
presence.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
@ -153,7 +163,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
return hs
def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.sync_handler = self.hs.get_sync_handler()
self.module_api = homeserver.get_module_api()
@ -176,7 +188,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
)
def test_receiving_all_presence_legacy(self):
def test_receiving_all_presence_legacy(self) -> None:
self.receiving_all_presence_test_body()
@override_config(
@ -193,10 +205,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
],
}
)
def test_receiving_all_presence(self):
def test_receiving_all_presence(self) -> None:
self.receiving_all_presence_test_body()
def receiving_all_presence_test_body(self):
def receiving_all_presence_test_body(self) -> None:
"""Test that a user that does not share a room with another other can receive
presence for them, due to presence routing.
"""
@ -302,7 +314,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
)
def test_send_local_online_presence_to_with_module_legacy(self):
def test_send_local_online_presence_to_with_module_legacy(self) -> None:
self.send_local_online_presence_to_with_module_test_body()
@override_config(
@ -321,10 +333,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
],
}
)
def test_send_local_online_presence_to_with_module(self):
def test_send_local_online_presence_to_with_module(self) -> None:
self.send_local_online_presence_to_with_module_test_body()
def send_local_online_presence_to_with_module_test_body(self):
def send_local_online_presence_to_with_module_test_body(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to a set
of specified local and remote users, with a custom PresenceRouter module enabled.
"""
@ -447,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
continue
# EDUs can contain multiple presence updates
for presence_update in edu["content"]["push"]:
for presence_edu in edu["content"]["push"]:
# Check for presence updates that contain the user IDs we're after
found_users.add(presence_update["user_id"])
found_users.add(presence_edu["user_id"])
# Ensure that no offline states are being sent out
self.assertNotEqual(presence_update["presence"], "offline")
self.assertNotEqual(presence_edu["presence"], "offline")
self.assertEqual(found_users, expected_users)
def send_presence_update(
testcase: TestCase,
testcase: FederatingHomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
@ -479,7 +491,7 @@ def send_presence_update(
def sync_presence(
testcase: TestCase,
testcase: FederatingHomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
@ -500,7 +512,7 @@ def sync_presence(
requester = create_requester(user_id)
sync_config = generate_sync_config(requester.user.to_string())
sync_result = testcase.get_success(
testcase.sync_handler.wait_for_sync_for_user(
testcase.hs.get_sync_handler().wait_for_sync_for_user(
requester, sync_config, since_token
)
)

View file

@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.test_utils.event_injection import create_event
@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase):
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
@ -35,7 +40,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self.user_tok = self.login("u1", "pass")
self.room_id = self.helper.create_room_as(tok=self.user_tok)
def test_serialize_deserialize_msg(self):
def test_serialize_deserialize_msg(self) -> None:
"""Test that an EventContext for a message event is the same after
serialize/deserialize.
"""
@ -51,7 +56,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_no_prev(self):
def test_serialize_deserialize_state_no_prev(self) -> None:
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
@ -67,7 +72,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_prev(self):
def test_serialize_deserialize_state_prev(self) -> None:
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
@ -84,7 +89,9 @@ class TestEventContext(unittest.HomeserverTestCase):
self._check_serialize_deserialize(event, context)
def _check_serialize_deserialize(self, event, context):
def _check_serialize_deserialize(
self, event: EventBase, context: EventContext
) -> None:
serialized = self.get_success(context.serialize(event, self.store))
d_context = EventContext.deserialize(self._storage_controllers, serialized)

View file

@ -13,21 +13,24 @@
# limitations under the License.
import unittest as stdlib_unittest
from typing import Any, List, Mapping, Optional
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import (
PowerLevelsContent,
SerializeEventConfig,
copy_and_fixup_power_levels_contents,
maybe_upsert_event_field,
prune_event,
serialize_event,
)
from synapse.types import JsonDict
from synapse.util.frozenutils import freeze
def MockEvent(**kwargs):
def MockEvent(**kwargs: Any) -> EventBase:
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
@ -60,7 +63,7 @@ class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
class PruneEventTestCase(stdlib_unittest.TestCase):
def run_test(self, evdict, matchdict, **kwargs):
def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None:
"""
Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted.
@ -74,7 +77,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
)
def test_minimal(self):
def test_minimal(self) -> None:
self.run_test(
{"type": "A", "event_id": "$test:domain"},
{
@ -86,7 +89,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
},
)
def test_basic_keys(self):
def test_basic_keys(self) -> None:
"""Ensure that the keys that should be untouched are kept."""
# Note that some of the values below don't really make sense, but the
# pruning of events doesn't worry about the values of any fields (with
@ -138,7 +141,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
def test_unsigned(self):
def test_unsigned(self) -> None:
"""Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
self.run_test(
{
@ -159,7 +162,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
},
)
def test_content(self):
def test_content(self) -> None:
"""The content dictionary should be stripped in most cases."""
self.run_test(
{"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
@ -194,7 +197,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
},
)
def test_create(self):
def test_create(self) -> None:
"""Create events are partially redacted until MSC2176."""
self.run_test(
{
@ -223,7 +226,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
def test_power_levels(self):
def test_power_levels(self) -> None:
"""Power level events keep a variety of content keys."""
self.run_test(
{
@ -273,7 +276,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
def test_alias_event(self):
def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6."""
self.run_test(
{
@ -302,7 +305,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
room_version=RoomVersions.V6,
)
def test_redacts(self):
def test_redacts(self) -> None:
"""Redaction events have no special behaviour until MSC2174/MSC2176."""
self.run_test(
@ -328,7 +331,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
room_version=RoomVersions.MSC2176,
)
def test_join_rules(self):
def test_join_rules(self) -> None:
"""Join rules events have changed behavior starting with MSC3083."""
self.run_test(
{
@ -371,7 +374,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
room_version=RoomVersions.V8,
)
def test_member(self):
def test_member(self) -> None:
"""Member events have changed behavior starting with MSC3375."""
self.run_test(
{
@ -417,12 +420,12 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
class SerializeEventTestCase(stdlib_unittest.TestCase):
def serialize(self, ev, fields):
def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict:
return serialize_event(
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
)
def test_event_fields_works_with_keys(self):
def test_event_fields_works_with_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
@ -430,7 +433,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
{"room_id": "!foo:bar"},
)
def test_event_fields_works_with_nested_keys(self):
def test_event_fields_works_with_nested_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@ -443,7 +446,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
{"content": {"body": "A message"}},
)
def test_event_fields_works_with_dot_keys(self):
def test_event_fields_works_with_dot_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@ -456,7 +459,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
{"content": {"key.with.dots": {}}},
)
def test_event_fields_works_with_nested_dot_keys(self):
def test_event_fields_works_with_nested_dot_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@ -472,7 +475,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
{"content": {"nested.dot.key": {"leaf.key": 42}}},
)
def test_event_fields_nops_with_unknown_keys(self):
def test_event_fields_nops_with_unknown_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@ -485,7 +488,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
{"content": {"foo": "bar"}},
)
def test_event_fields_nops_with_non_dict_keys(self):
def test_event_fields_nops_with_non_dict_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@ -498,7 +501,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
{},
)
def test_event_fields_nops_with_array_keys(self):
def test_event_fields_nops_with_array_keys(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@ -511,7 +514,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
{},
)
def test_event_fields_all_fields_if_empty(self):
def test_event_fields_all_fields_if_empty(self) -> None:
self.assertEqual(
self.serialize(
MockEvent(
@ -531,16 +534,16 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
},
)
def test_event_fields_fail_if_fields_not_str(self):
def test_event_fields_fail_if_fields_not_str(self) -> None:
with self.assertRaises(TypeError):
self.serialize(
MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item]
)
class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def setUp(self) -> None:
self.test_content = {
self.test_content: PowerLevelsContent = {
"ban": 50,
"events": {"m.room.name": 100, "m.room.power_levels": 100},
"events_default": 0,
@ -553,10 +556,11 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
"users_default": 0,
}
def _test(self, input):
def _test(self, input: PowerLevelsContent) -> None:
a = copy_and_fixup_power_levels_contents(input)
self.assertEqual(a["ban"], 50)
assert isinstance(a["events"], Mapping)
self.assertEqual(a["events"]["m.room.name"], 100)
# make sure that changing the copy changes the copy and not the orig
@ -564,18 +568,19 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
a["events"]["m.room.power_levels"] = 20
self.assertEqual(input["ban"], 50)
assert isinstance(input["events"], Mapping)
self.assertEqual(input["events"]["m.room.power_levels"], 100)
def test_unfrozen(self):
def test_unfrozen(self) -> None:
self._test(self.test_content)
def test_frozen(self):
def test_frozen(self) -> None:
input = freeze(self.test_content)
self._test(input)
def test_stringy_integers(self):
def test_stringy_integers(self) -> None:
"""String representations of decimal integers are converted to integers."""
input = {
input: PowerLevelsContent = {
"a": "100",
"b": {
"foo": 99,
@ -603,9 +608,9 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def test_invalid_types_raise_type_error(self) -> None:
with self.assertRaises(TypeError):
copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[arg-type]
copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[arg-type]
copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[dict-item]
copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[dict-item]
def test_invalid_nesting_raises_type_error(self) -> None:
with self.assertRaises(TypeError):
copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}})
copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item]