Add type hints to tests/rest/client (#12094)

* Add type hints to `tests/rest/client`

* update `mypy.ini`

* newsfile

* add `test_register.py`
This commit is contained in:
Dirk Klimpel 2022-02-28 19:59:00 +01:00 committed by GitHub
parent 7754af24ab
commit 952efd0bca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 64 deletions

1
changelog.d/12094.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `tests/rest/client`.

View file

@ -75,10 +75,7 @@ exclude = (?x)
|tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_account.py
|tests/rest/client/test_events.py
|tests/rest/client/test_filter.py
|tests/rest/client/test_groups.py
|tests/rest/client/test_register.py
|tests/rest/client/test_report_event.py
|tests/rest/client/test_rooms.py
|tests/rest/client/test_third_party_rules.py

View file

@ -16,8 +16,12 @@
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.rest.client import events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["enable_registration_captcha"] = False
@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
hs.get_federation_handler = Mock()
hs.get_federation_handler = Mock() # type: ignore[assignment]
return hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register an account
self.user_id = self.register_user("sid1", "pass")
@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
self.other_user = self.register_user("other2", "pass")
self.other_token = self.login(self.other_user, "pass")
def test_stream_basic_permissions(self):
def test_stream_basic_permissions(self) -> None:
# invalid token, expect 401
# note: this is in violation of the original v1 spec, which expected
# 403. However, since the v1 spec no longer exists and the v1
@ -76,7 +80,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
self.assertTrue("start" in channel.json_body)
self.assertTrue("end" in channel.json_body)
def test_stream_room_permissions(self):
def test_stream_room_permissions(self) -> None:
room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
self.helper.send(room_id, tok=self.other_token)
@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# left to room (expect no content for room)
def TODO_test_stream_items(self):
def TODO_test_stream_items(self) -> None:
# new user, no content
# join room, expect 1 item (join)
@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def prepare(self, hs, reactor, clock):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register an account
self.user_id = self.register_user("sid1", "pass")
@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
def test_get_event_via_events(self):
def test_get_event_via_events(self) -> None:
resp = self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"]

View file

@ -25,7 +25,7 @@ class GroupsTestCase(unittest.HomeserverTestCase):
servlets = [room.register_servlets, groups.register_servlets]
@override_config({"enable_group_creation": True})
def test_rooms_limited_by_visibility(self):
def test_rooms_limited_by_visibility(self) -> None:
group_id = "+spqr:test"
# Alice creates a group

View file

@ -16,15 +16,21 @@
import datetime
import json
import os
from typing import Any, Dict, List, Tuple
import pkg_resources
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
from synapse.server import HomeServer
from synapse.storage._base import db_to_json
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
]
url = b"/_matrix/client/r0/register"
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["allow_guest_access"] = True
return config
def test_POST_appservice_registration_valid(self):
def test_POST_appservice_registration_valid(self) -> None:
user_id = "@as_user_kermit:test"
as_token = "i_am_an_app_service"
@ -69,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_appservice_registration_no_type(self):
def test_POST_appservice_registration_no_type(self) -> None:
as_token = "i_am_an_app_service"
appservice = ApplicationService(
@ -89,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"400", channel.result)
def test_POST_appservice_registration_invalid(self):
def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists
request_data = json.dumps(
{"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
@ -100,21 +106,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
def test_POST_bad_password(self) -> None:
request_data = json.dumps({"username": "kermit", "password": 666})
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self):
def test_POST_bad_username(self) -> None:
request_data = json.dumps({"username": 777, "password": "monkey"})
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self):
def test_POST_user_valid(self) -> None:
user_id = "@kermit:test"
device_id = "frogfone"
params = {
@ -135,7 +141,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
def test_POST_disabled_registration(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
@ -145,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
def test_POST_guest_registration(self) -> None:
self.hs.config.key.macaroon_secret_key = "test"
self.hs.config.registration.allow_guest_access = True
@ -155,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self):
def test_POST_disabled_guest_registration(self) -> None:
self.hs.config.registration.allow_guest_access = False
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@ -164,7 +170,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self):
def test_POST_ratelimiting_guest(self) -> None:
for i in range(0, 6):
url = self.url + b"?kind=guest"
channel = self.make_request(b"POST", url, b"{}")
@ -182,7 +188,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self):
def test_POST_ratelimiting(self) -> None:
for i in range(0, 6):
params = {
"username": "kermit" + str(i),
@ -206,7 +212,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
@override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self):
def test_POST_registration_requires_token(self) -> None:
username = "kermit"
device_id = "frogfone"
token = "abcd"
@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
},
)
)
params = {
params: JsonDict = {
"username": username,
"password": "monkey",
"device_id": device_id,
@ -280,8 +286,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["pending"], 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self):
params = {
def test_POST_registration_token_invalid(self) -> None:
params: JsonDict = {
"username": "kermit",
"password": "monkey",
}
@ -314,7 +320,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True})
def test_POST_registration_token_limit_uses(self):
def test_POST_registration_token_limit_uses(self) -> None:
token = "abcd"
store = self.hs.get_datastores().main
# Create token that can be used once
@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
},
)
)
params1 = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"}
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
@ -388,7 +394,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["completed"], [])
@override_config({"registration_requires_token": True})
def test_POST_registration_token_expiry(self):
def test_POST_registration_token_expiry(self) -> None:
token = "abcd"
now = self.hs.get_clock().time_msec()
store = self.hs.get_datastores().main
@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
},
)
)
params = {"username": "kermit", "password": "monkey"}
params: JsonDict = {"username": "kermit", "password": "monkey"}
# Request without auth to get session
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
@ -436,7 +442,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry(self):
def test_POST_registration_token_session_expiry(self) -> None:
"""Test `pending` is decremented when an uncompleted session expires."""
token = "abcd"
store = self.hs.get_datastores().main
@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
# Do 2 requests without auth to get two session IDs
params1 = {"username": "bert", "password": "monkey"}
params2 = {"username": "ernie", "password": "monkey"}
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
@ -522,7 +528,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(pending, 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_session_expiry_deleted_token(self):
def test_POST_registration_token_session_expiry_deleted_token(self) -> None:
"""Test session expiry doesn't break when the token is deleted.
1. Start but don't complete UIA with a registration token
@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
# Do request without auth to get a session ID
params = {"username": "kermit", "password": "monkey"}
params: JsonDict = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
@ -570,7 +576,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
)
def test_advertised_flows(self):
def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@ -593,7 +599,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self):
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@ -625,7 +631,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
},
}
)
def test_advertised_flows_no_msisdn_email_required(self):
def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
},
}
)
def test_request_token_existing_email_inhibit_error(self):
def test_request_token_existing_email_inhibit_error(self) -> None:
"""Test that requesting a token via this endpoint doesn't leak existing
associations if configured that way.
"""
@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
},
}
)
def test_reject_invalid_email(self):
def test_reject_invalid_email(self) -> None:
"""Check that bad emails are rejected"""
# Test for email with multiple @
@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"inhibit_user_in_use_error": True,
}
)
def test_inhibit_user_in_use_error(self):
def test_inhibit_user_in_use_error(self) -> None:
"""Tests that the 'inhibit_user_in_use_error' configuration flag behaves
correctly.
"""
@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
account_validity.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Test for account expiring after a week.
config["enable_registration"] = True
@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
return self.hs
def test_validity_period(self):
def test_validity_period(self) -> None:
self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
@ -810,7 +816,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
def test_manual_renewal(self):
def test_manual_renewal(self) -> None:
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
@ -833,7 +839,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
def test_manual_expire(self):
def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
@ -858,7 +864,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
def test_logging_out_expired_user(self):
def test_logging_out_expired_user(self) -> None:
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
account.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Test for account expiring after a week and renewal emails being sent 2
@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver(config=config)
async def sendmail(*args, **kwargs):
async def sendmail(*args: Any, **kwargs: Any) -> None:
self.email_attempts.append((args, kwargs))
self.email_attempts = []
self.email_attempts: List[Tuple[Any, Any]] = []
self.hs.get_send_email_handler()._sendmail = sendmail
self.store = self.hs.get_datastores().main
return self.hs
def test_renewal_email(self):
def test_renewal_email(self) -> None:
self.email_attempts = []
(user_id, tok) = self.create_user()
@ -999,7 +1005,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self):
def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
channel.result["body"], expected_html.encode("utf8"), channel.result
)
def test_manual_email_send(self):
def test_manual_email_send(self) -> None:
self.email_attempts = []
(user_id, tok) = self.create_user()
@ -1032,7 +1038,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 1)
def test_deactivated_user(self):
def test_deactivated_user(self) -> None:
self.email_attempts = []
(user_id, tok) = self.create_user()
@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 0)
def create_user(self):
def create_user(self) -> Tuple[str, str]:
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
)
return user_id, tok
def test_manual_email_send_expired_account(self):
def test_manual_email_send_expired_account(self) -> None:
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.validity_period = 10
self.max_delta = self.validity_period * 10.0 / 100.0
@ -1135,7 +1141,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
return self.hs
def test_background_job(self):
def test_background_job(self) -> None:
"""
Tests the same thing as test_background_job, except that it sets the
startup_job_max_delta parameter and checks that the expiration date is within the
@ -1158,12 +1164,12 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
url = "/_matrix/client/v1/register/m.login.registration_token/validity"
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["registration_requires_token"] = True
return config
def test_GET_token_valid(self):
def test_GET_token_valid(self) -> None:
token = "abcd"
store = self.hs.get_datastores().main
self.get_success(
@ -1186,7 +1192,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self):
def test_GET_token_invalid(self) -> None:
token = "1234"
channel = self.make_request(
b"GET",
@ -1198,7 +1204,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
@override_config(
{"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
)
def test_GET_ratelimiting(self):
def test_GET_ratelimiting(self) -> None:
token = "1234"
for i in range(0, 6):