mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-12 04:52:26 +01:00
Make _make_callback_with_userinfo
async
... so that we can test its behaviour when it raises. Also pull it out to the top level so that I can use it from other test classes.
This commit is contained in:
parent
c1883f042d
commit
8388a7fb3a
1 changed files with 81 additions and 66 deletions
|
@ -21,6 +21,7 @@ import pymacaroons
|
||||||
|
|
||||||
from synapse.handlers.oidc_handler import OidcError
|
from synapse.handlers.oidc_handler import OidcError
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests.test_utils import FakeResponse, simple_async_mock
|
from tests.test_utils import FakeResponse, simple_async_mock
|
||||||
|
@ -399,7 +400,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
client_redirect_url=client_redirect_url,
|
client_redirect_url=client_redirect_url,
|
||||||
ui_auth_session_id=None,
|
ui_auth_session_id=None,
|
||||||
)
|
)
|
||||||
request = self._build_callback_request(
|
request = _build_callback_request(
|
||||||
code, state, session, user_agent=user_agent, ip_address=ip_address
|
code, state, session, user_agent=user_agent, ip_address=ip_address
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -607,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
client_redirect_url=client_redirect_url,
|
client_redirect_url=client_redirect_url,
|
||||||
ui_auth_session_id=None,
|
ui_auth_session_id=None,
|
||||||
)
|
)
|
||||||
request = self._build_callback_request("code", state, session)
|
request = _build_callback_request("code", state, session)
|
||||||
|
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
|
@ -624,7 +625,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": "test_user",
|
"sub": "test_user",
|
||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
}
|
}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", ANY, ANY, None,
|
"@test_user:test", ANY, ANY, None,
|
||||||
)
|
)
|
||||||
|
@ -635,7 +636,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": 1234,
|
"sub": 1234,
|
||||||
"username": "test_user_2",
|
"username": "test_user_2",
|
||||||
}
|
}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user_2:test", ANY, ANY, None,
|
"@test_user_2:test", ANY, ANY, None,
|
||||||
)
|
)
|
||||||
|
@ -648,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
store.register_user(user_id=user3.to_string(), password_hash=None)
|
store.register_user(user_id=user3.to_string(), password_hash=None)
|
||||||
)
|
)
|
||||||
userinfo = {"sub": "test3", "username": "test_user_3"}
|
userinfo = {"sub": "test3", "username": "test_user_3"}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_not_called()
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
self.assertRenderedError(
|
self.assertRenderedError(
|
||||||
"mapping_error",
|
"mapping_error",
|
||||||
|
@ -672,14 +673,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": "test",
|
"sub": "test",
|
||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
}
|
}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None,
|
user.to_string(), ANY, ANY, None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None,
|
user.to_string(), ANY, ANY, None,
|
||||||
)
|
)
|
||||||
|
@ -694,7 +695,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": "test1",
|
"sub": "test1",
|
||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
}
|
}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), ANY, ANY, None,
|
user.to_string(), ANY, ANY, None,
|
||||||
)
|
)
|
||||||
|
@ -715,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": "test2",
|
"sub": "test2",
|
||||||
"username": "TEST_USER_2",
|
"username": "TEST_USER_2",
|
||||||
}
|
}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_not_called()
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
args = self.assertRenderedError("mapping_error")
|
args = self.assertRenderedError("mapping_error")
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
@ -730,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
store.register_user(user_id=user2.to_string(), password_hash=None)
|
store.register_user(user_id=user2.to_string(), password_hash=None)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@TEST_USER_2:test", ANY, ANY, None,
|
"@TEST_USER_2:test", ANY, ANY, None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_userinfo_to_invalid_localpart(self):
|
def test_map_userinfo_to_invalid_localpart(self):
|
||||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
|
self.get_success(
|
||||||
|
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
||||||
|
)
|
||||||
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
|
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -762,7 +765,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": "test",
|
"sub": "test",
|
||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
}
|
}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
|
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
@ -784,68 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"sub": "tester",
|
"sub": "tester",
|
||||||
"username": "tester",
|
"username": "tester",
|
||||||
}
|
}
|
||||||
self._make_callback_with_userinfo(userinfo)
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_not_called()
|
auth_handler.complete_sso_login.assert_not_called()
|
||||||
self.assertRenderedError(
|
self.assertRenderedError(
|
||||||
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
|
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_callback_with_userinfo(
|
|
||||||
self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
|
|
||||||
) -> None:
|
|
||||||
self.handler._exchange_code = simple_async_mock(return_value={})
|
|
||||||
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
||||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
|
||||||
|
|
||||||
state = "state"
|
async def _make_callback_with_userinfo(
|
||||||
session = self.handler._generate_oidc_session_token(
|
hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
|
||||||
state=state,
|
) -> None:
|
||||||
nonce="nonce",
|
"""Mock up an OIDC callback with the given userinfo dict
|
||||||
client_redirect_url=client_redirect_url,
|
|
||||||
ui_auth_session_id=None,
|
|
||||||
)
|
|
||||||
request = self._build_callback_request("code", state, session)
|
|
||||||
|
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
|
||||||
|
and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
|
||||||
|
|
||||||
def _build_callback_request(
|
Args:
|
||||||
self,
|
hs: the HomeServer impl to send the callback to.
|
||||||
code: str,
|
userinfo: the OIDC userinfo dict
|
||||||
state: str,
|
client_redirect_url: the URL to redirect to on success.
|
||||||
session: str,
|
"""
|
||||||
user_agent: str = "Browser",
|
handler = hs.get_oidc_handler()
|
||||||
ip_address: str = "10.0.0.1",
|
handler._exchange_code = simple_async_mock(return_value={})
|
||||||
):
|
handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
"""Builds a fake SynapseRequest to mock the browser callback
|
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||||
|
|
||||||
Returns a Mock object which looks like the SynapseRequest we get from a browser
|
state = "state"
|
||||||
after SSO (before we return to the client)
|
session = handler._generate_oidc_session_token(
|
||||||
|
state=state,
|
||||||
|
nonce="nonce",
|
||||||
|
client_redirect_url=client_redirect_url,
|
||||||
|
ui_auth_session_id=None,
|
||||||
|
)
|
||||||
|
request = _build_callback_request("code", state, session)
|
||||||
|
|
||||||
Args:
|
await handler.handle_oidc_callback(request)
|
||||||
code: the authorization code which would have been returned by the OIDC
|
|
||||||
provider
|
|
||||||
state: the "state" param which would have been passed around in the
|
|
||||||
query param. Should be the same as was embedded in the session in
|
|
||||||
_build_oidc_session.
|
|
||||||
session: the "session" which would have been passed around in the cookie.
|
|
||||||
user_agent: the user-agent to present
|
|
||||||
ip_address: the IP address to pretend the request came from
|
|
||||||
"""
|
|
||||||
request = Mock(
|
|
||||||
spec=[
|
|
||||||
"args",
|
|
||||||
"getCookie",
|
|
||||||
"addCookie",
|
|
||||||
"requestHeaders",
|
|
||||||
"getClientIP",
|
|
||||||
"get_user_agent",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
request.getCookie.return_value = session
|
|
||||||
request.args = {}
|
def _build_callback_request(
|
||||||
request.args[b"code"] = [code.encode("utf-8")]
|
code: str,
|
||||||
request.args[b"state"] = [state.encode("utf-8")]
|
state: str,
|
||||||
request.getClientIP.return_value = ip_address
|
session: str,
|
||||||
request.get_user_agent.return_value = user_agent
|
user_agent: str = "Browser",
|
||||||
return request
|
ip_address: str = "10.0.0.1",
|
||||||
|
):
|
||||||
|
"""Builds a fake SynapseRequest to mock the browser callback
|
||||||
|
|
||||||
|
Returns a Mock object which looks like the SynapseRequest we get from a browser
|
||||||
|
after SSO (before we return to the client)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: the authorization code which would have been returned by the OIDC
|
||||||
|
provider
|
||||||
|
state: the "state" param which would have been passed around in the
|
||||||
|
query param. Should be the same as was embedded in the session in
|
||||||
|
_build_oidc_session.
|
||||||
|
session: the "session" which would have been passed around in the cookie.
|
||||||
|
user_agent: the user-agent to present
|
||||||
|
ip_address: the IP address to pretend the request came from
|
||||||
|
"""
|
||||||
|
request = Mock(
|
||||||
|
spec=[
|
||||||
|
"args",
|
||||||
|
"getCookie",
|
||||||
|
"addCookie",
|
||||||
|
"requestHeaders",
|
||||||
|
"getClientIP",
|
||||||
|
"get_user_agent",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
request.getCookie.return_value = session
|
||||||
|
request.args = {}
|
||||||
|
request.args[b"code"] = [code.encode("utf-8")]
|
||||||
|
request.args[b"state"] = [state.encode("utf-8")]
|
||||||
|
request.getClientIP.return_value = ip_address
|
||||||
|
request.get_user_agent.return_value = user_agent
|
||||||
|
return request
|
||||||
|
|
Loading…
Reference in a new issue