0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-09-24 18:49:01 +02:00

Make _make_callback_with_userinfo async

... so that we can test its behaviour when it raises.

Also pull it out to the top level so that I can use it from other test classes.
This commit is contained in:
Richard van der Hoff 2020-12-15 13:03:31 +00:00
parent c1883f042d
commit 8388a7fb3a

View file

@ -21,6 +21,7 @@ import pymacaroons
from synapse.handlers.oidc_handler import OidcError from synapse.handlers.oidc_handler import OidcError
from synapse.handlers.sso import MappingException from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import UserID
from tests.test_utils import FakeResponse, simple_async_mock from tests.test_utils import FakeResponse, simple_async_mock
@ -399,7 +400,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url, client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ui_auth_session_id=None,
) )
request = self._build_callback_request( request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address code, state, session, user_agent=user_agent, ip_address=ip_address
) )
@ -607,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url, client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ui_auth_session_id=None,
) )
request = self._build_callback_request("code", state, session) request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
@ -624,7 +625,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test_user", "sub": "test_user",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", ANY, ANY, None, "@test_user:test", ANY, ANY, None,
) )
@ -635,7 +636,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": 1234, "sub": 1234,
"username": "test_user_2", "username": "test_user_2",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", ANY, ANY, None, "@test_user_2:test", ANY, ANY, None,
) )
@ -648,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None) store.register_user(user_id=user3.to_string(), password_hash=None)
) )
userinfo = {"sub": "test3", "username": "test_user_3"} userinfo = {"sub": "test3", "username": "test_user_3"}
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError( self.assertRenderedError(
"mapping_error", "mapping_error",
@ -672,14 +673,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test", "sub": "test",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, user.to_string(), ANY, ANY, None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, user.to_string(), ANY, ANY, None,
) )
@ -694,7 +695,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1", "sub": "test1",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, user.to_string(), ANY, ANY, None,
) )
@ -715,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2", "sub": "test2",
"username": "TEST_USER_2", "username": "TEST_USER_2",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error") args = self.assertRenderedError("mapping_error")
self.assertTrue( self.assertTrue(
@ -730,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None) store.register_user(user_id=user2.to_string(), password_hash=None)
) )
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", ANY, ANY, None, "@TEST_USER_2:test", ANY, ANY, None,
) )
def test_map_userinfo_to_invalid_localpart(self): def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
self._make_callback_with_userinfo({"sub": "test2", "username": "föö"}) self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
)
self.assertRenderedError("mapping_error", "localpart is invalid: föö") self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config( @override_config(
@ -762,7 +765,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test", "sub": "test",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
# test_user is already taken, so test_user1 gets registered instead. # test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
@ -784,32 +787,44 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester", "sub": "tester",
"username": "tester", "username": "tester",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError( self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response" "mapping_error", "Unable to generate a Matrix ID from the SSO response"
) )
def _make_callback_with_userinfo(
self, userinfo: dict, client_redirect_url: str = "http://client/redirect" async def _make_callback_with_userinfo(
hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None: ) -> None:
self.handler._exchange_code = simple_async_mock(return_value={}) """Mock up an OIDC callback with the given userinfo dict
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
Args:
hs: the HomeServer impl to send the callback to.
userinfo: the OIDC userinfo dict
client_redirect_url: the URL to redirect to on success.
"""
handler = hs.get_oidc_handler()
handler._exchange_code = simple_async_mock(return_value={})
handler._parse_id_token = simple_async_mock(return_value=userinfo)
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state" state = "state"
session = self.handler._generate_oidc_session_token( session = handler._generate_oidc_session_token(
state=state, state=state,
nonce="nonce", nonce="nonce",
client_redirect_url=client_redirect_url, client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ui_auth_session_id=None,
) )
request = self._build_callback_request("code", state, session) request = _build_callback_request("code", state, session)
await handler.handle_oidc_callback(request)
self.get_success(self.handler.handle_oidc_callback(request))
def _build_callback_request( def _build_callback_request(
self,
code: str, code: str,
state: str, state: str,
session: str, session: str,