forked from MirrorHub/synapse
Use HTTPStatus constants in place of literals in tests. (#13297)
This commit is contained in:
parent
7b67e93d49
commit
96cf81e312
9 changed files with 308 additions and 238 deletions
1
changelog.d/13297.misc
Normal file
1
changelog.d/13297.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Use `HTTPStatus` constants in place of literals in tests.
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
@ -50,7 +51,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
channel = self.make_signed_federation_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(HTTPStatus.OK, channel.code)
|
||||||
complexity = channel.json_body["v1"]
|
complexity = channel.json_body["v1"]
|
||||||
self.assertTrue(complexity > 0, complexity)
|
self.assertTrue(complexity > 0, complexity)
|
||||||
|
|
||||||
|
@ -62,7 +63,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
channel = self.make_signed_federation_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code)
|
self.assertEqual(HTTPStatus.OK, channel.code)
|
||||||
complexity = channel.json_body["v1"]
|
complexity = channel.json_body["v1"]
|
||||||
self.assertEqual(complexity, 1.23)
|
self.assertEqual(complexity, 1.23)
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
@ -58,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
|
||||||
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
|
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
|
||||||
query_content,
|
query_content,
|
||||||
)
|
)
|
||||||
self.assertEqual(400, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
|
||||||
|
|
||||||
|
|
||||||
|
@ -119,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
|
||||||
channel = self.make_signed_federation_request(
|
channel = self.make_signed_federation_request(
|
||||||
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
|
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
|
||||||
)
|
)
|
||||||
self.assertEqual(403, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,7 +154,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
|
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
|
||||||
f"?ver={DEFAULT_ROOM_VERSION}",
|
f"?ver={DEFAULT_ROOM_VERSION}",
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
return channel.json_body
|
return channel.json_body
|
||||||
|
|
||||||
def test_send_join(self):
|
def test_send_join(self):
|
||||||
|
@ -171,7 +172,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
|
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
|
||||||
content=join_event_dict,
|
content=join_event_dict,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
|
|
||||||
# we should get complete room state back
|
# we should get complete room state back
|
||||||
returned_state = [
|
returned_state = [
|
||||||
|
@ -226,7 +227,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||||
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
|
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
|
||||||
content=join_event_dict,
|
content=join_event_dict,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
|
|
||||||
# expect a reduced room state
|
# expect a reduced room state
|
||||||
returned_state = [
|
returned_state = [
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
|
@ -255,7 +256,7 @@ class FederationKnockingTestCase(
|
||||||
RoomVersions.V7.identifier,
|
RoomVersions.V7.identifier,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
# Note: We don't expect the knock membership event to be sent over federation as
|
# Note: We don't expect the knock membership event to be sent over federation as
|
||||||
# part of the stripped room state, as the knocking homeserver already has that
|
# part of the stripped room state, as the knocking homeserver already has that
|
||||||
|
@ -293,7 +294,7 @@ class FederationKnockingTestCase(
|
||||||
% (room_id, signed_knock_event.event_id),
|
% (room_id, signed_knock_event.event_id),
|
||||||
signed_knock_event_json,
|
signed_knock_event_json,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
# Check that we got the stripped room state in return
|
# Check that we got the stripped room state in return
|
||||||
room_state_events = channel.json_body["knock_state_events"]
|
room_state_events = channel.json_body["knock_state_events"]
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
"""Tests for the password_auth_provider interface"""
|
"""Tests for the password_auth_provider interface"""
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Type, Union
|
from typing import Any, Type, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
@ -188,14 +189,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# check_password must return an awaitable
|
# check_password must return an awaitable
|
||||||
mock_password_provider.check_password.return_value = make_awaitable(True)
|
mock_password_provider.check_password.return_value = make_awaitable(True)
|
||||||
channel = self._send_password_login("u", "p")
|
channel = self._send_password_login("u", "p")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertEqual("@u:test", channel.json_body["user_id"])
|
self.assertEqual("@u:test", channel.json_body["user_id"])
|
||||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
|
||||||
# login with mxid should work too
|
# login with mxid should work too
|
||||||
channel = self._send_password_login("@u:bz", "p")
|
channel = self._send_password_login("@u:bz", "p")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertEqual("@u:bz", channel.json_body["user_id"])
|
self.assertEqual("@u:bz", channel.json_body["user_id"])
|
||||||
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
|
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
|
||||||
mock_password_provider.reset_mock()
|
mock_password_provider.reset_mock()
|
||||||
|
@ -204,7 +205,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# in these cases, but at least we can guard against the API changing
|
# in these cases, but at least we can guard against the API changing
|
||||||
# unexpectedly
|
# unexpectedly
|
||||||
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
|
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
|
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
|
||||||
mock_password_provider.check_password.assert_called_once_with(
|
mock_password_provider.check_password.assert_called_once_with(
|
||||||
"@ USER🙂NAME :test", " pASS😢word "
|
"@ USER🙂NAME :test", " pASS😢word "
|
||||||
|
@ -258,10 +259,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# check_password must return an awaitable
|
# check_password must return an awaitable
|
||||||
mock_password_provider.check_password.return_value = make_awaitable(False)
|
mock_password_provider.check_password.return_value = make_awaitable(False)
|
||||||
channel = self._send_password_login("u", "p")
|
channel = self._send_password_login("u", "p")
|
||||||
self.assertEqual(channel.code, 403, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
|
||||||
|
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||||
|
@ -382,7 +383,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||||
channel = self._send_password_login("u", "p")
|
channel = self._send_password_login("u", "p")
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
mock_password_provider.check_password.assert_not_called()
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||||
|
@ -406,14 +407,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# login with missing param should be rejected
|
# login with missing param should be rejected
|
||||||
channel = self._send_login("test.login_type", "u")
|
channel = self._send_login("test.login_type", "u")
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||||
("@user:bz", None)
|
("@user:bz", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||||
mock_password_provider.check_auth.assert_called_once_with(
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
"u", "test.login_type", {"test_field": "y"}
|
"u", "test.login_type", {"test_field": "y"}
|
||||||
|
@ -427,7 +428,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
("@ MALFORMED! :bz", None)
|
("@ MALFORMED! :bz", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
|
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
|
||||||
mock_password_provider.check_auth.assert_called_once_with(
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
||||||
|
@ -510,7 +511,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
("@user:bz", callback)
|
("@user:bz", callback)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||||
mock_password_provider.check_auth.assert_called_once_with(
|
mock_password_provider.check_auth.assert_called_once_with(
|
||||||
"u", "test.login_type", {"test_field": "y"}
|
"u", "test.login_type", {"test_field": "y"}
|
||||||
|
@ -549,7 +550,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -584,7 +585,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
mock_password_provider.check_auth.assert_not_called()
|
mock_password_provider.check_auth.assert_not_called()
|
||||||
mock_password_provider.check_password.assert_not_called()
|
mock_password_provider.check_password.assert_not_called()
|
||||||
|
|
||||||
|
@ -646,13 +647,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
("@localuser:test", None)
|
("@localuser:test", None)
|
||||||
)
|
)
|
||||||
channel = self._send_login("test.login_type", "localuser", test_field="")
|
channel = self._send_login("test.login_type", "localuser", test_field="")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
tok1 = channel.json_body["access_token"]
|
tok1 = channel.json_body["access_token"]
|
||||||
|
|
||||||
channel = self._send_login(
|
channel = self._send_login(
|
||||||
"test.login_type", "localuser", test_field="", device_id="dev2"
|
"test.login_type", "localuser", test_field="", device_id="dev2"
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
# make the initial request which returns a 401
|
# make the initial request which returns a 401
|
||||||
channel = self._delete_device(tok1, "dev2")
|
channel = self._delete_device(tok1, "dev2")
|
||||||
|
@ -721,7 +722,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
# password login shouldn't work and should be rejected with a 400
|
# password login shouldn't work and should be rejected with a 400
|
||||||
# ("unknown login type")
|
# ("unknown login type")
|
||||||
channel = self._send_password_login("localuser", "localpass")
|
channel = self._send_password_login("localuser", "localpass")
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
|
|
||||||
def test_on_logged_out(self):
|
def test_on_logged_out(self):
|
||||||
"""Tests that the on_logged_out callback is called when the user logs out."""
|
"""Tests that the on_logged_out callback is called when the user logs out."""
|
||||||
|
@ -884,7 +885,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
access_token=tok,
|
access_token=tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 403, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["errcode"],
|
channel.json_body["errcode"],
|
||||||
Codes.THREEPID_DENIED,
|
Codes.THREEPID_DENIED,
|
||||||
|
@ -906,7 +907,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
access_token=tok,
|
access_token=tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
self.assertIn("sid", channel.json_body)
|
self.assertIn("sid", channel.json_body)
|
||||||
|
|
||||||
m.assert_called_once_with("email", "bar@test.com", registration)
|
m.assert_called_once_with("email", "bar@test.com", registration)
|
||||||
|
@ -949,12 +950,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
"register",
|
"register",
|
||||||
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
return channel.json_body
|
return channel.json_body
|
||||||
|
|
||||||
def _get_login_flows(self) -> JsonDict:
|
def _get_login_flows(self) -> JsonDict:
|
||||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
return channel.json_body["flows"]
|
return channel.json_body["flows"]
|
||||||
|
|
||||||
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
||||||
|
|
|
@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
content=body,
|
content=body,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
|
@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
content=body,
|
content=body,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
|
@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
content={"password": "abc123", "admin": False},
|
content={"password": "abc123", "admin": False},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertFalse(channel.json_body["admin"])
|
self.assertFalse(channel.json_body["admin"])
|
||||||
|
|
||||||
|
@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Admin user is not blocked by mau anymore
|
# Admin user is not blocked by mau anymore
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertFalse(channel.json_body["admin"])
|
self.assertFalse(channel.json_body["admin"])
|
||||||
|
|
||||||
|
@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
content=body,
|
content=body,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
|
@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
content=body,
|
content=body,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
|
@ -1666,7 +1666,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
content=body,
|
content=body,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
|
||||||
|
@ -2407,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
content={"password": "abc123"},
|
content={"password": "abc123"},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual("bob", channel.json_body["displayname"])
|
self.assertEqual("bob", channel.json_body["displayname"])
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from email.parser import Parser
|
from email.parser import Parser
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
@ -98,7 +99,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
|
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 403, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
|
||||||
|
|
||||||
def test_basic_password_reset(self) -> None:
|
def test_basic_password_reset(self) -> None:
|
||||||
"""Test basic password reset flow"""
|
"""Test basic password reset flow"""
|
||||||
|
@ -347,7 +348,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
|
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
|
||||||
# password reset confirm button
|
# password reset confirm button
|
||||||
|
@ -362,7 +363,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
content_is_form=True,
|
content_is_form=True,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
def _get_link_from_email(self) -> str:
|
def _get_link_from_email(self) -> str:
|
||||||
assert self.email_attempts, "No emails have been sent"
|
assert self.email_attempts, "No emails have been sent"
|
||||||
|
@ -390,7 +391,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||||
new_password: str,
|
new_password: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
client_secret: str,
|
client_secret: str,
|
||||||
expected_code: int = 200,
|
expected_code: int = HTTPStatus.OK,
|
||||||
) -> None:
|
) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
|
@ -715,7 +716,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(
|
||||||
|
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||||
|
)
|
||||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
|
@ -725,7 +728,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
def test_delete_email(self) -> None:
|
def test_delete_email(self) -> None:
|
||||||
|
@ -747,7 +750,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
{"medium": "email", "address": self.email},
|
{"medium": "email", "address": self.email},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -756,7 +759,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
def test_delete_email_if_disabled(self) -> None:
|
def test_delete_email_if_disabled(self) -> None:
|
||||||
|
@ -781,7 +784,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(
|
||||||
|
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||||
|
)
|
||||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
|
@ -791,7 +796,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
|
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
|
||||||
|
|
||||||
|
@ -817,7 +822,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(
|
||||||
|
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||||
|
)
|
||||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
|
@ -827,7 +834,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
def test_no_valid_token(self) -> None:
|
def test_no_valid_token(self) -> None:
|
||||||
|
@ -852,7 +859,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(
|
||||||
|
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||||
|
)
|
||||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
|
@ -862,7 +871,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
self.assertFalse(channel.json_body["threepids"])
|
self.assertFalse(channel.json_body["threepids"])
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": None})
|
@override_config({"next_link_domain_whitelist": None})
|
||||||
|
@ -872,7 +881,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link="https://example.com/a/good/site",
|
next_link="https://example.com/a/good/site",
|
||||||
expect_code=200,
|
expect_code=HTTPStatus.OK,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": None})
|
@override_config({"next_link_domain_whitelist": None})
|
||||||
|
@ -884,7 +893,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
|
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
|
||||||
expect_code=200,
|
expect_code=HTTPStatus.OK,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": None})
|
@override_config({"next_link_domain_whitelist": None})
|
||||||
|
@ -895,7 +904,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link="file:///host/path",
|
next_link="file:///host/path",
|
||||||
expect_code=400,
|
expect_code=HTTPStatus.BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
|
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
|
||||||
|
@ -907,28 +916,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link=None,
|
next_link=None,
|
||||||
expect_code=200,
|
expect_code=HTTPStatus.OK,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._request_token(
|
self._request_token(
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link="https://example.com/some/good/page",
|
next_link="https://example.com/some/good/page",
|
||||||
expect_code=200,
|
expect_code=HTTPStatus.OK,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._request_token(
|
self._request_token(
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link="https://example.org/some/also/good/page",
|
next_link="https://example.org/some/also/good/page",
|
||||||
expect_code=200,
|
expect_code=HTTPStatus.OK,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._request_token(
|
self._request_token(
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link="https://bad.example.org/some/bad/page",
|
next_link="https://bad.example.org/some/bad/page",
|
||||||
expect_code=400,
|
expect_code=HTTPStatus.BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"next_link_domain_whitelist": []})
|
@override_config({"next_link_domain_whitelist": []})
|
||||||
|
@ -940,7 +949,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
"something@example.com",
|
"something@example.com",
|
||||||
"some_secret",
|
"some_secret",
|
||||||
next_link="https://example.com/a/page",
|
next_link="https://example.com/a/page",
|
||||||
expect_code=400,
|
expect_code=HTTPStatus.BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _request_token(
|
def _request_token(
|
||||||
|
@ -948,7 +957,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
email: str,
|
email: str,
|
||||||
client_secret: str,
|
client_secret: str,
|
||||||
next_link: Optional[str] = None,
|
next_link: Optional[str] = None,
|
||||||
expect_code: int = 200,
|
expect_code: int = HTTPStatus.OK,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Request a validation token to add an email address to a user's account
|
"""Request a validation token to add an email address to a user's account
|
||||||
|
|
||||||
|
@ -993,7 +1002,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
b"account/3pid/email/requestToken",
|
b"account/3pid/email/requestToken",
|
||||||
{"client_secret": client_secret, "email": email, "send_attempt": 1},
|
{"client_secret": client_secret, "email": email, "send_attempt": 1},
|
||||||
)
|
)
|
||||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
self.assertEqual(
|
||||||
|
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||||
|
)
|
||||||
self.assertEqual(expected_errcode, channel.json_body["errcode"])
|
self.assertEqual(expected_errcode, channel.json_body["errcode"])
|
||||||
self.assertEqual(expected_error, channel.json_body["error"])
|
self.assertEqual(expected_error, channel.json_body["error"])
|
||||||
|
|
||||||
|
@ -1002,7 +1013,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
path = link.replace("https://example.com", "")
|
path = link.replace("https://example.com", "")
|
||||||
|
|
||||||
channel = self.make_request("GET", path, shorthand=False)
|
channel = self.make_request("GET", path, shorthand=False)
|
||||||
self.assertEqual(200, channel.code, channel.result)
|
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||||
|
|
||||||
def _get_link_from_email(self) -> str:
|
def _get_link_from_email(self) -> str:
|
||||||
assert self.email_attempts, "No emails have been sent"
|
assert self.email_attempts, "No emails have been sent"
|
||||||
|
@ -1052,7 +1063,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -1061,7 +1072,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=self.user_id_tok,
|
access_token=self.user_id_tok,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
|
|
||||||
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
|
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
|
||||||
|
@ -1092,7 +1103,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests that not providing any MXID raises an error."""
|
"""Tests that not providing any MXID raises an error."""
|
||||||
self._test_status(
|
self._test_status(
|
||||||
users=None,
|
users=None,
|
||||||
expected_status_code=400,
|
expected_status_code=HTTPStatus.BAD_REQUEST,
|
||||||
expected_errcode=Codes.MISSING_PARAM,
|
expected_errcode=Codes.MISSING_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1100,7 +1111,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests that providing an invalid MXID raises an error."""
|
"""Tests that providing an invalid MXID raises an error."""
|
||||||
self._test_status(
|
self._test_status(
|
||||||
users=["bad:test"],
|
users=["bad:test"],
|
||||||
expected_status_code=400,
|
expected_status_code=HTTPStatus.BAD_REQUEST,
|
||||||
expected_errcode=Codes.INVALID_PARAM,
|
expected_errcode=Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1286,7 +1297,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||||
def _test_status(
|
def _test_status(
|
||||||
self,
|
self,
|
||||||
users: Optional[List[str]],
|
users: Optional[List[str]],
|
||||||
expected_status_code: int = 200,
|
expected_status_code: int = HTTPStatus.OK,
|
||||||
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
|
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
|
||||||
expected_failures: Optional[List[str]] = None,
|
expected_failures: Optional[List[str]] = None,
|
||||||
expected_errcode: Optional[str] = None,
|
expected_errcode: Optional[str] = None,
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
@ -261,20 +262,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
|
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
access_token = channel.json_body["access_token"]
|
access_token = channel.json_body["access_token"]
|
||||||
device_id = channel.json_body["device_id"]
|
device_id = channel.json_body["device_id"]
|
||||||
|
|
||||||
# we should now be able to make requests with the access token
|
# we should now be able to make requests with the access token
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
# time passes
|
# time passes
|
||||||
self.reactor.advance(24 * 3600)
|
self.reactor.advance(24 * 3600)
|
||||||
|
|
||||||
# ... and we should be soft-logouted
|
# ... and we should be soft-logouted
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 401, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||||
|
|
||||||
|
@ -288,7 +289,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
# more requests with the expired token should still return a soft-logout
|
# more requests with the expired token should still return a soft-logout
|
||||||
self.reactor.advance(3600)
|
self.reactor.advance(3600)
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 401, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||||
|
|
||||||
|
@ -296,7 +297,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self._delete_device(access_token_2, "kermit", "monkey", device_id)
|
self._delete_device(access_token_2, "kermit", "monkey", device_id)
|
||||||
|
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 401, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||||
self.assertEqual(channel.json_body["soft_logout"], False)
|
self.assertEqual(channel.json_body["soft_logout"], False)
|
||||||
|
|
||||||
|
@ -307,7 +308,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
b"DELETE", "devices/" + device_id, access_token=access_token
|
b"DELETE", "devices/" + device_id, access_token=access_token
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 401, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||||
# check it's a UI-Auth fail
|
# check it's a UI-Auth fail
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(channel.json_body.keys()),
|
set(channel.json_body.keys()),
|
||||||
|
@ -330,7 +331,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
content={"auth": auth},
|
content={"auth": auth},
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
@override_config({"session_lifetime": "24h"})
|
@override_config({"session_lifetime": "24h"})
|
||||||
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
|
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
|
||||||
|
@ -341,14 +342,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# we should now be able to make requests with the access token
|
# we should now be able to make requests with the access token
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
# time passes
|
# time passes
|
||||||
self.reactor.advance(24 * 3600)
|
self.reactor.advance(24 * 3600)
|
||||||
|
|
||||||
# ... and we should be soft-logouted
|
# ... and we should be soft-logouted
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 401, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||||
|
|
||||||
|
@ -367,14 +368,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# we should now be able to make requests with the access token
|
# we should now be able to make requests with the access token
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
# time passes
|
# time passes
|
||||||
self.reactor.advance(24 * 3600)
|
self.reactor.advance(24 * 3600)
|
||||||
|
|
||||||
# ... and we should be soft-logouted
|
# ... and we should be soft-logouted
|
||||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||||
self.assertEqual(channel.code, 401, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||||
|
|
||||||
|
@ -466,7 +467,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
def test_get_login_flows(self) -> None:
|
def test_get_login_flows(self) -> None:
|
||||||
"""GET /login should return password and SSO flows"""
|
"""GET /login should return password and SSO flows"""
|
||||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
expected_flow_types = [
|
expected_flow_types = [
|
||||||
"m.login.cas",
|
"m.login.cas",
|
||||||
|
@ -494,14 +495,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
"""/login/sso/redirect should redirect to an identity picker"""
|
"""/login/sso/redirect should redirect to an identity picker"""
|
||||||
# first hit the redirect url, which should redirect to our idp picker
|
# first hit the redirect url, which should redirect to our idp picker
|
||||||
channel = self._make_sso_redirect_request(None)
|
channel = self._make_sso_redirect_request(None)
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
uri = location_headers[0]
|
uri = location_headers[0]
|
||||||
|
|
||||||
# hitting that picker should give us some HTML
|
# hitting that picker should give us some HTML
|
||||||
channel = self.make_request("GET", uri)
|
channel = self.make_request("GET", uri)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
# parse the form to check it has fields assumed elsewhere in this class
|
# parse the form to check it has fields assumed elsewhere in this class
|
||||||
html = channel.result["body"].decode("utf-8")
|
html = channel.result["body"].decode("utf-8")
|
||||||
|
@ -530,7 +531,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
+ "&idp=cas",
|
+ "&idp=cas",
|
||||||
shorthand=False,
|
shorthand=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
cas_uri = location_headers[0]
|
cas_uri = location_headers[0]
|
||||||
|
@ -555,7 +556,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||||
+ "&idp=saml",
|
+ "&idp=saml",
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
saml_uri = location_headers[0]
|
saml_uri = location_headers[0]
|
||||||
|
@ -579,7 +580,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||||
+ "&idp=oidc",
|
+ "&idp=oidc",
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
oidc_uri = location_headers[0]
|
oidc_uri = location_headers[0]
|
||||||
|
@ -606,7 +607,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
|
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
|
||||||
|
|
||||||
# that should serve a confirmation page
|
# that should serve a confirmation page
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
content_type_headers = channel.headers.getRawHeaders("Content-Type")
|
content_type_headers = channel.headers.getRawHeaders("Content-Type")
|
||||||
assert content_type_headers
|
assert content_type_headers
|
||||||
self.assertTrue(content_type_headers[-1].startswith("text/html"))
|
self.assertTrue(content_type_headers[-1].startswith("text/html"))
|
||||||
|
@ -634,7 +635,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
"/login",
|
"/login",
|
||||||
content={"type": "m.login.token", "token": login_token},
|
content={"type": "m.login.token", "token": login_token},
|
||||||
)
|
)
|
||||||
self.assertEqual(chan.code, 200, chan.result)
|
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
|
||||||
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
||||||
|
|
||||||
def test_multi_sso_redirect_to_unknown(self) -> None:
|
def test_multi_sso_redirect_to_unknown(self) -> None:
|
||||||
|
@ -643,18 +644,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
"GET",
|
"GET",
|
||||||
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
|
|
||||||
def test_client_idp_redirect_to_unknown(self) -> None:
|
def test_client_idp_redirect_to_unknown(self) -> None:
|
||||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||||
channel = self._make_sso_redirect_request("xxx")
|
channel = self._make_sso_redirect_request("xxx")
|
||||||
self.assertEqual(channel.code, 404, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
def test_client_idp_redirect_to_oidc(self) -> None:
|
def test_client_idp_redirect_to_oidc(self) -> None:
|
||||||
"""If the client pick a known IdP, redirect to it"""
|
"""If the client pick a known IdP, redirect to it"""
|
||||||
channel = self._make_sso_redirect_request("oidc")
|
channel = self._make_sso_redirect_request("oidc")
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
oidc_uri = location_headers[0]
|
oidc_uri = location_headers[0]
|
||||||
|
@ -765,7 +766,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
channel = self.make_request("GET", cas_ticket_url)
|
channel = self.make_request("GET", cas_ticket_url)
|
||||||
|
|
||||||
# Test that the response is HTML.
|
# Test that the response is HTML.
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
content_type_header_value = ""
|
content_type_header_value = ""
|
||||||
for header in channel.result.get("headers", []):
|
for header in channel.result.get("headers", []):
|
||||||
if header[0] == b"Content-Type":
|
if header[0] == b"Content-Type":
|
||||||
|
@ -1246,7 +1247,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# that should redirect to the username picker
|
# that should redirect to the username picker
|
||||||
self.assertEqual(channel.code, 302, channel.result)
|
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||||
location_headers = channel.headers.getRawHeaders("Location")
|
location_headers = channel.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
picker_url = location_headers[0]
|
picker_url = location_headers[0]
|
||||||
|
@ -1290,7 +1291,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||||
("Content-Length", str(len(content))),
|
("Content-Length", str(len(content))),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
self.assertEqual(chan.code, 302, chan.result)
|
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
|
||||||
location_headers = chan.headers.getRawHeaders("Location")
|
location_headers = chan.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
|
|
||||||
|
@ -1300,7 +1301,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||||
path=location_headers[0],
|
path=location_headers[0],
|
||||||
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
|
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
|
||||||
)
|
)
|
||||||
self.assertEqual(chan.code, 302, chan.result)
|
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
|
||||||
location_headers = chan.headers.getRawHeaders("Location")
|
location_headers = chan.headers.getRawHeaders("Location")
|
||||||
assert location_headers
|
assert location_headers
|
||||||
|
|
||||||
|
@ -1325,5 +1326,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||||
"/login",
|
"/login",
|
||||||
content={"type": "m.login.token", "token": login_token},
|
content={"type": "m.login.token", "token": login_token},
|
||||||
)
|
)
|
||||||
self.assertEqual(chan.code, 200, chan.result)
|
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
|
||||||
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
|
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue