Add type hints to tests/rest/client (#12072)

This commit is contained in:
Dirk Klimpel 2022-02-24 19:56:38 +01:00 committed by GitHub
parent 2cc5ea933d
commit 54e74cc15f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 160 additions and 102 deletions

1
changelog.d/12072.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `tests/rest/client`.

View file

@ -13,11 +13,16 @@
# limitations under the License. # limitations under the License.
import os import os
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.rest.consent import consent_resource from synapse.rest.consent import consent_resource
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request from tests.server import FakeSite, make_request
@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
user_id = True user_id = True
hijack_auth = False hijack_auth = False
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["form_secret"] = "123abc" config["form_secret"] = "123abc"
@ -56,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config) hs = self.setup_test_homeserver(config=config)
return hs return hs
def test_render_public_consent(self): def test_render_public_consent(self) -> None:
"""You can observe the terms form without specifying a user""" """You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs) resource = consent_resource.ConsentResource(self.hs)
channel = make_request( channel = make_request(
@ -66,9 +71,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"/consent?v=1", "/consent?v=1",
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
def test_accept_consent(self): def test_accept_consent(self) -> None:
""" """
A user can use the consent form to accept the terms. A user can use the consent form to accept the terms.
""" """
@ -92,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
access_token=access_token, access_token=access_token,
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
# Get the version from the body, and whether we've consented # Get the version from the body, and whether we've consented
version, consented = channel.result["body"].decode("ascii").split(",") version, consented = channel.result["body"].decode("ascii").split(",")
@ -107,7 +112,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
access_token=access_token, access_token=access_token,
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
# Fetch the consent page, to get the consent version -- it should have # Fetch the consent page, to get the consent version -- it should have
# changed # changed
@ -119,7 +124,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
access_token=access_token, access_token=access_token,
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
# Get the version from the body, and check that it's the version we # Get the version from the body, and check that it's the version we
# agreed to, and that we've consented to it. # agreed to, and that we've consented to it.

View file

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from synapse.rest import admin, devices, room, sync from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, login, register from synapse.rest.client import account, login, register
@ -30,7 +32,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
devices.register_servlets, devices.register_servlets,
] ]
def test_receiving_local_device_list_changes(self): def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list """Tests that a local users that share a room receive each other's device list
changes. changes.
""" """
@ -84,7 +86,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
}, },
access_token=alice_access_token, access_token=alice_access_token,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync contains the updated device list. # Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the # If not, the client would only receive the device list update on the
@ -97,7 +99,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
) )
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self): def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not """Tests a local users DO NOT receive device updates from each other if they do not
share a room. share a room.
""" """
@ -119,7 +121,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
"/sync", "/sync",
access_token=bob_access_token, access_token=bob_access_token,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
next_batch_token = channel.json_body["next_batch"] next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up, # ...and then an incremental sync. This should block until the sync stream is woken up,
@ -141,11 +143,13 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
}, },
access_token=alice_access_token, access_token=alice_access_token,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync does not contain the updated device list. # Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result() bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) self.assertEqual(
bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", [] "changed", []

View file

@ -11,9 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventContentFields, EventTypes from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import room from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["enable_ephemeral_messages"] = True config["enable_ephemeral_messages"] = True
@ -35,10 +42,10 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver(config=config) self.hs = self.setup_test_homeserver(config=config)
return self.hs return self.hs
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
def test_message_expiry_no_delay(self): def test_message_expiry_no_delay(self) -> None:
"""Tests that sending a message sent with a m.self_destruct_after field set to the """Tests that sending a message sent with a m.self_destruct_after field set to the
past results in that event being deleted right away. past results in that event being deleted right away.
""" """
@ -61,7 +68,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
event_content = self.get_event(self.room_id, event_id)["content"] event_content = self.get_event(self.room_id, event_id)["content"]
self.assertFalse(bool(event_content), event_content) self.assertFalse(bool(event_content), event_content)
def test_message_expiry_delay(self): def test_message_expiry_delay(self) -> None:
"""Tests that sending a message with a m.self_destruct_after field set to the """Tests that sending a message with a m.self_destruct_after field set to the
future results in that event not being deleted right away, but advancing the future results in that event not being deleted right away, but advancing the
clock to after that expiry timestamp causes the event to be deleted. clock to after that expiry timestamp causes the event to be deleted.
@ -89,7 +96,9 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
event_content = self.get_event(self.room_id, event_id)["content"] event_content = self.get_event(self.room_id, event_id)["content"]
self.assertFalse(bool(event_content), event_content) self.assertFalse(bool(event_content), event_content)
def get_event(self, room_id, event_id, expected_code=200): def get_event(
self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK
) -> JsonDict:
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
channel = self.make_request("GET", url) channel = self.make_request("GET", url)

View file

@ -13,9 +13,14 @@
# limitations under the License. # limitations under the License.
import json import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["enable_3pid_lookup"] = False config["enable_3pid_lookup"] = False
@ -36,14 +41,14 @@ class IdentityTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_3pid_lookup_disabled(self): def test_3pid_lookup_disabled(self) -> None:
self.hs.config.registration.enable_3pid_lookup = False self.hs.config.registration.enable_3pid_lookup = False
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
room_id = channel.json_body["room_id"] room_id = channel.json_body["room_id"]
params = { params = {
@ -56,4 +61,4 @@ class IdentityTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
b"POST", request_url, request_data, access_token=tok b"POST", request_url, request_data, access_token=tok
) )
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)

View file

@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def test_rejects_device_id_ice_key_outside_of_list(self): def test_rejects_device_id_ice_key_outside_of_list(self) -> None:
self.register_user("alice", "wonderland") self.register_user("alice", "wonderland")
alice_token = self.login("alice", "wonderland") alice_token = self.login("alice", "wonderland")
bob = self.register_user("bob", "uncle") bob = self.register_user("bob", "uncle")
@ -49,7 +49,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
channel.result, channel.result,
) )
def test_rejects_device_key_given_as_map_to_bool(self): def test_rejects_device_key_given_as_map_to_bool(self) -> None:
self.register_user("alice", "wonderland") self.register_user("alice", "wonderland")
alice_token = self.login("alice", "wonderland") alice_token = self.login("alice", "wonderland")
bob = self.register_user("bob", "uncle") bob = self.register_user("bob", "uncle")
@ -73,7 +73,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
channel.result, channel.result,
) )
def test_requires_device_key(self): def test_requires_device_key(self) -> None:
"""`device_keys` is required. We should complain if it's missing.""" """`device_keys` is required. We should complain if it's missing."""
self.register_user("alice", "wonderland") self.register_user("alice", "wonderland")
alice_token = self.login("alice", "wonderland") alice_token = self.login("alice", "wonderland")

View file

@ -13,11 +13,16 @@
# limitations under the License. # limitations under the License.
import json import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import account, login, password_policy, register from synapse.rest.client import account, login, password_policy, register
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
account.register_servlets, account.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.register_url = "/_matrix/client/r0/register" self.register_url = "/_matrix/client/r0/register"
self.policy = { self.policy = {
"enabled": True, "enabled": True,
@ -65,12 +70,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config) hs = self.setup_test_homeserver(config=config)
return hs return hs
def test_get_policy(self): def test_get_policy(self) -> None:
"""Tests if the /password_policy endpoint returns the configured policy.""" """Tests if the /password_policy endpoint returns the configured policy."""
channel = self.make_request("GET", "/_matrix/client/r0/password_policy") channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{ {
@ -83,70 +88,70 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
channel.result, channel.result,
) )
def test_password_too_short(self): def test_password_too_short(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "shorty"}) request_data = json.dumps({"username": "kermit", "password": "shorty"})
channel = self.make_request("POST", self.register_url, request_data) channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], channel.json_body["errcode"],
Codes.PASSWORD_TOO_SHORT, Codes.PASSWORD_TOO_SHORT,
channel.result, channel.result,
) )
def test_password_no_digit(self): def test_password_no_digit(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
channel = self.make_request("POST", self.register_url, request_data) channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], channel.json_body["errcode"],
Codes.PASSWORD_NO_DIGIT, Codes.PASSWORD_NO_DIGIT,
channel.result, channel.result,
) )
def test_password_no_symbol(self): def test_password_no_symbol(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
channel = self.make_request("POST", self.register_url, request_data) channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], channel.json_body["errcode"],
Codes.PASSWORD_NO_SYMBOL, Codes.PASSWORD_NO_SYMBOL,
channel.result, channel.result,
) )
def test_password_no_uppercase(self): def test_password_no_uppercase(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
channel = self.make_request("POST", self.register_url, request_data) channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], channel.json_body["errcode"],
Codes.PASSWORD_NO_UPPERCASE, Codes.PASSWORD_NO_UPPERCASE,
channel.result, channel.result,
) )
def test_password_no_lowercase(self): def test_password_no_lowercase(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
channel = self.make_request("POST", self.register_url, request_data) channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], channel.json_body["errcode"],
Codes.PASSWORD_NO_LOWERCASE, Codes.PASSWORD_NO_LOWERCASE,
channel.result, channel.result,
) )
def test_password_compliant(self): def test_password_compliant(self) -> None:
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
channel = self.make_request("POST", self.register_url, request_data) channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has # Getting a 401 here means the password has passed validation and the server has
# responded with a list of registration flows. # responded with a list of registration flows.
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
def test_password_change(self): def test_password_change(self) -> None:
"""This doesn't test every possible use case, only that hitting /account/password """This doesn't test every possible use case, only that hitting /account/password
triggers the password validation code. triggers the password validation code.
""" """
@ -173,5 +178,5 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
access_token=tok, access_token=tok,
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)

View file

@ -11,11 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room, sync from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
return self.setup_test_homeserver(config=config) return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# register a room admin, moderator and regular user # register a room admin, moderator and regular user
self.admin_user_id = self.register_user("admin", "pass") self.admin_user_id = self.register_user("admin", "pass")
self.admin_access_token = self.login("admin", "pass") self.admin_access_token = self.login("admin", "pass")
@ -88,7 +93,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
tok=self.admin_access_token, tok=self.admin_access_token,
) )
def test_non_admins_cannot_enable_room_encryption(self): def test_non_admins_cannot_enable_room_encryption(self) -> None:
# have the mod try to enable room encryption # have the mod try to enable room encryption
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
@ -104,10 +109,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.encryption", "m.room.encryption",
{"algorithm": "m.megolm.v1.aes-sha2"}, {"algorithm": "m.megolm.v1.aes-sha2"},
tok=self.user_access_token, tok=self.user_access_token,
expect_code=403, # expect failure expect_code=HTTPStatus.FORBIDDEN, # expect failure
) )
def test_non_admins_cannot_send_server_acl(self): def test_non_admins_cannot_send_server_acl(self) -> None:
# have the mod try to send a server ACL # have the mod try to send a server ACL
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
@ -118,7 +123,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"deny": ["*.evil.com", "evil.com"], "deny": ["*.evil.com", "evil.com"],
}, },
tok=self.mod_access_token, tok=self.mod_access_token,
expect_code=403, # expect failure expect_code=HTTPStatus.FORBIDDEN, # expect failure
) )
# have the user try to send a server ACL # have the user try to send a server ACL
@ -131,10 +136,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"deny": ["*.evil.com", "evil.com"], "deny": ["*.evil.com", "evil.com"],
}, },
tok=self.user_access_token, tok=self.user_access_token,
expect_code=403, # expect failure expect_code=HTTPStatus.FORBIDDEN, # expect failure
) )
def test_non_admins_cannot_tombstone_room(self): def test_non_admins_cannot_tombstone_room(self) -> None:
# Create another room that will serve as our "upgraded room" # Create another room that will serve as our "upgraded room"
self.upgraded_room_id = self.helper.create_room_as( self.upgraded_room_id = self.helper.create_room_as(
self.admin_user_id, tok=self.admin_access_token self.admin_user_id, tok=self.admin_access_token
@ -149,7 +154,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"replacement_room": self.upgraded_room_id, "replacement_room": self.upgraded_room_id,
}, },
tok=self.mod_access_token, tok=self.mod_access_token,
expect_code=403, # expect failure expect_code=HTTPStatus.FORBIDDEN, # expect failure
) )
# have the user try to send a tombstone event # have the user try to send a tombstone event
@ -164,17 +169,17 @@ class PowerLevelsTestCase(HomeserverTestCase):
expect_code=403, # expect failure expect_code=403, # expect failure
) )
def test_admins_can_enable_room_encryption(self): def test_admins_can_enable_room_encryption(self) -> None:
# have the admin try to enable room encryption # have the admin try to enable room encryption
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
"m.room.encryption", "m.room.encryption",
{"algorithm": "m.megolm.v1.aes-sha2"}, {"algorithm": "m.megolm.v1.aes-sha2"},
tok=self.admin_access_token, tok=self.admin_access_token,
expect_code=200, # expect success expect_code=HTTPStatus.OK, # expect success
) )
def test_admins_can_send_server_acl(self): def test_admins_can_send_server_acl(self) -> None:
# have the admin try to send a server ACL # have the admin try to send a server ACL
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
@ -185,10 +190,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"deny": ["*.evil.com", "evil.com"], "deny": ["*.evil.com", "evil.com"],
}, },
tok=self.admin_access_token, tok=self.admin_access_token,
expect_code=200, # expect success expect_code=HTTPStatus.OK, # expect success
) )
def test_admins_can_tombstone_room(self): def test_admins_can_tombstone_room(self) -> None:
# Create another room that will serve as our "upgraded room" # Create another room that will serve as our "upgraded room"
self.upgraded_room_id = self.helper.create_room_as( self.upgraded_room_id = self.helper.create_room_as(
self.admin_user_id, tok=self.admin_access_token self.admin_user_id, tok=self.admin_access_token
@ -203,10 +208,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
"replacement_room": self.upgraded_room_id, "replacement_room": self.upgraded_room_id,
}, },
tok=self.admin_access_token, tok=self.admin_access_token,
expect_code=200, # expect success expect_code=HTTPStatus.OK, # expect success
) )
def test_cannot_set_string_power_levels(self): def test_cannot_set_string_power_levels(self) -> None:
room_power_levels = self.helper.get_state( room_power_levels = self.helper.get_state(
self.room_id, self.room_id,
"m.room.power_levels", "m.room.power_levels",
@ -221,7 +226,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.power_levels", "m.room.power_levels",
room_power_levels, room_power_levels,
tok=self.admin_access_token, tok=self.admin_access_token,
expect_code=400, # expect failure expect_code=HTTPStatus.BAD_REQUEST, # expect failure
) )
self.assertEqual( self.assertEqual(
@ -230,7 +235,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
body, body,
) )
def test_cannot_set_unsafe_large_power_levels(self): def test_cannot_set_unsafe_large_power_levels(self) -> None:
room_power_levels = self.helper.get_state( room_power_levels = self.helper.get_state(
self.room_id, self.room_id,
"m.room.power_levels", "m.room.power_levels",
@ -247,7 +252,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.power_levels", "m.room.power_levels",
room_power_levels, room_power_levels,
tok=self.admin_access_token, tok=self.admin_access_token,
expect_code=400, # expect failure expect_code=HTTPStatus.BAD_REQUEST, # expect failure
) )
self.assertEqual( self.assertEqual(
@ -256,7 +261,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
body, body,
) )
def test_cannot_set_unsafe_small_power_levels(self): def test_cannot_set_unsafe_small_power_levels(self) -> None:
room_power_levels = self.helper.get_state( room_power_levels = self.helper.get_state(
self.room_id, self.room_id,
"m.room.power_levels", "m.room.power_levels",
@ -273,7 +278,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
"m.room.power_levels", "m.room.power_levels",
room_power_levels, room_power_levels,
tok=self.admin_access_token, tok=self.admin_access_token,
expect_code=400, # expect failure expect_code=HTTPStatus.BAD_REQUEST, # expect failure
) )
self.assertEqual( self.assertEqual(

View file

@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
from synapse.rest.client import presence from synapse.rest.client import presence
from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
servlets = [presence.register_servlets] servlets = [presence.register_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
presence_handler = Mock(spec=PresenceHandler) presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None) presence_handler.set_state.return_value = defer.succeed(None)
@ -45,7 +48,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
return hs return hs
def test_put_presence(self): def test_put_presence(self) -> None:
""" """
PUT to the status endpoint with use_presence enabled will call PUT to the status endpoint with use_presence enabled will call
set_state on the presence handler. set_state on the presence handler.
@ -57,11 +60,11 @@ class PresenceTestCase(unittest.HomeserverTestCase):
"PUT", "/presence/%s/status" % (self.user_id,), body "PUT", "/presence/%s/status" % (self.user_id,), body
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
@unittest.override_config({"use_presence": False}) @unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self): def test_put_presence_disabled(self) -> None:
""" """
PUT to the status endpoint with use_presence disabled will NOT call PUT to the status endpoint with use_presence disabled will NOT call
set_state on the presence handler. set_state on the presence handler.
@ -72,5 +75,5 @@ class PresenceTestCase(unittest.HomeserverTestCase):
"PUT", "/presence/%s/status" % (self.user_id,), body "PUT", "/presence/%s/status" % (self.user_id,), body
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)

View file

@ -134,7 +134,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
return room_id, event_id_a, event_id_b, event_id_c return room_id, event_id_a, event_id_b, event_id_c
@unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) @unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
def test_same_state_groups_for_whole_historical_batch(self): def test_same_state_groups_for_whole_historical_batch(self) -> None:
"""Make sure that when using the `/batch_send` endpoint to import a """Make sure that when using the `/batch_send` endpoint to import a
bunch of historical messages, it re-uses the same `state_group` across bunch of historical messages, it re-uses the same `state_group` across
the whole batch. This is an easy optimization to make sure we're getting the whole batch. This is an easy optimization to make sure we're getting

View file

@ -19,6 +19,7 @@ import json
import re import re
import time import time
import urllib.parse import urllib.parse
from http import HTTPStatus
from typing import ( from typing import (
Any, Any,
AnyStr, AnyStr,
@ -89,7 +90,7 @@ class RestHelper:
is_public: Optional[bool] = None, is_public: Optional[bool] = None,
room_version: Optional[str] = None, room_version: Optional[str] = None,
tok: Optional[str] = None, tok: Optional[str] = None,
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
extra_content: Optional[Dict] = None, extra_content: Optional[Dict] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> Optional[str]: ) -> Optional[str]:
@ -137,12 +138,19 @@ class RestHelper:
assert channel.result["code"] == b"%d" % expect_code, channel.result assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id self.auth_user_id = temp_id
if expect_code == 200: if expect_code == HTTPStatus.OK:
return channel.json_body["room_id"] return channel.json_body["room_id"]
else: else:
return None return None
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): def invite(
self,
room: Optional[str] = None,
src: Optional[str] = None,
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
self.change_membership( self.change_membership(
room=room, room=room,
src=src, src=src,
@ -156,7 +164,7 @@ class RestHelper:
self, self,
room: str, room: str,
user: Optional[str] = None, user: Optional[str] = None,
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None, tok: Optional[str] = None,
appservice_user_id: Optional[str] = None, appservice_user_id: Optional[str] = None,
) -> None: ) -> None:
@ -170,7 +178,14 @@ class RestHelper:
expect_code=expect_code, expect_code=expect_code,
) )
def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None): def knock(
self,
room: Optional[str] = None,
user: Optional[str] = None,
reason: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
temp_id = self.auth_user_id temp_id = self.auth_user_id
self.auth_user_id = user self.auth_user_id = user
path = "/knock/%s" % room path = "/knock/%s" % room
@ -199,7 +214,13 @@ class RestHelper:
self.auth_user_id = temp_id self.auth_user_id = temp_id
def leave(self, room=None, user=None, expect_code=200, tok=None): def leave(
self,
room: Optional[str] = None,
user: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
self.change_membership( self.change_membership(
room=room, room=room,
src=user, src=user,
@ -209,7 +230,7 @@ class RestHelper:
expect_code=expect_code, expect_code=expect_code,
) )
def ban(self, room: str, src: str, targ: str, **kwargs: object): def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None:
"""A convenience helper: `change_membership` with `membership` preset to "ban".""" """A convenience helper: `change_membership` with `membership` preset to "ban"."""
self.change_membership( self.change_membership(
room=room, room=room,
@ -228,7 +249,7 @@ class RestHelper:
extra_data: Optional[dict] = None, extra_data: Optional[dict] = None,
tok: Optional[str] = None, tok: Optional[str] = None,
appservice_user_id: Optional[str] = None, appservice_user_id: Optional[str] = None,
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None, expect_errcode: Optional[str] = None,
) -> None: ) -> None:
""" """
@ -294,13 +315,13 @@ class RestHelper:
def send( def send(
self, self,
room_id, room_id: str,
body=None, body: Optional[str] = None,
txn_id=None, txn_id: Optional[str] = None,
tok=None, tok: Optional[str] = None,
expect_code=200, expect_code: int = HTTPStatus.OK,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
): ) -> JsonDict:
if body is None: if body is None:
body = "body_text_here" body = "body_text_here"
@ -318,14 +339,14 @@ class RestHelper:
def send_event( def send_event(
self, self,
room_id, room_id: str,
type, type: str,
content: Optional[dict] = None, content: Optional[dict] = None,
txn_id=None, txn_id: Optional[str] = None,
tok=None, tok: Optional[str] = None,
expect_code=200, expect_code: int = HTTPStatus.OK,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
): ) -> JsonDict:
if txn_id is None: if txn_id is None:
txn_id = "m%s" % (str(time.time())) txn_id = "m%s" % (str(time.time()))
@ -358,10 +379,10 @@ class RestHelper:
event_type: str, event_type: str,
body: Optional[Dict[str, Any]], body: Optional[Dict[str, Any]],
tok: str, tok: str,
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
state_key: str = "", state_key: str = "",
method: str = "GET", method: str = "GET",
) -> Dict: ) -> JsonDict:
"""Read or write some state from a given room """Read or write some state from a given room
Args: Args:
@ -410,9 +431,9 @@ class RestHelper:
room_id: str, room_id: str,
event_type: str, event_type: str,
tok: str, tok: str,
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
state_key: str = "", state_key: str = "",
): ) -> JsonDict:
"""Gets some state from a room """Gets some state from a room
Args: Args:
@ -438,9 +459,9 @@ class RestHelper:
event_type: str, event_type: str,
body: Dict[str, Any], body: Dict[str, Any],
tok: str, tok: str,
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
state_key: str = "", state_key: str = "",
): ) -> JsonDict:
"""Set some state in a room """Set some state in a room
Args: Args:
@ -467,8 +488,8 @@ class RestHelper:
image_data: bytes, image_data: bytes,
tok: str, tok: str,
filename: str = "test.png", filename: str = "test.png",
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
) -> dict: ) -> JsonDict:
"""Upload a piece of test media to the media repo """Upload a piece of test media to the media repo
Args: Args:
resource: The resource that will handle the upload request resource: The resource that will handle the upload request
@ -513,7 +534,7 @@ class RestHelper:
channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url) channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
# expect a confirmation page # expect a confirmation page
assert channel.code == 200, channel.result assert channel.code == HTTPStatus.OK, channel.result
# fish the matrix login token out of the body of the confirmation page # fish the matrix login token out of the body of the confirmation page
m = re.search( m = re.search(
@ -532,7 +553,7 @@ class RestHelper:
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
assert channel.code == 200 assert channel.code == HTTPStatus.OK
return channel.json_body return channel.json_body
def auth_via_oidc( def auth_via_oidc(
@ -641,7 +662,7 @@ class RestHelper:
(expected_uri, resp_obj) = expected_requests.pop(0) (expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri assert uri == expected_uri
resp = FakeResponse( resp = FakeResponse(
code=200, code=HTTPStatus.OK,
phrase=b"OK", phrase=b"OK",
body=json.dumps(resp_obj).encode("utf-8"), body=json.dumps(resp_obj).encode("utf-8"),
) )
@ -739,7 +760,7 @@ class RestHelper:
self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
) )
# that should serve a confirmation page # that should serve a confirmation page
assert channel.code == 200, channel.text_body assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies) channel.extract_cookies(cookies)
# parse the confirmation page to fish out the link. # parse the confirmation page to fish out the link.