0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 04:52:26 +01:00

Merge pull request #8951 from matrix-org/rav/username_picker_2

More preparatory refactoring of the OidcHandler tests
This commit is contained in:
Richard van der Hoff 2020-12-16 14:53:26 +00:00 committed by GitHub
commit 7a332850e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 80 deletions

1
changelog.d/8951.feature Normal file
View file

@ -0,0 +1 @@
Add support for allowing users to pick their own user ID during a single-sign-on login.

View file

@ -19,8 +19,9 @@ from mock import ANY, Mock, patch
import pymacaroons import pymacaroons
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.oidc_handler import OidcError
from synapse.handlers.sso import MappingException from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import UserID
from tests.test_utils import FakeResponse, simple_async_mock from tests.test_utils import FakeResponse, simple_async_mock
@ -55,11 +56,14 @@ COOKIE_NAME = b"oidc_session"
COOKIE_PATH = "/_synapse/oidc" COOKIE_PATH = "/_synapse/oidc"
class TestMappingProvider(OidcMappingProvider): class TestMappingProvider:
@staticmethod @staticmethod
def parse_config(config): def parse_config(config):
return return
def __init__(self, config):
pass
def get_remote_user_id(self, userinfo): def get_remote_user_id(self, userinfo):
return userinfo["sub"] return userinfo["sub"]
@ -360,6 +364,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
- when the userinfo fetching fails - when the userinfo fetching fails
- when the code exchange fails - when the code exchange fails
""" """
# ensure that we are correctly testing the fallback when "get_extra_attributes"
# is not implemented.
mapping_provider = self.handler._user_mapping_provider
with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes
token = { token = {
"type": "bearer", "type": "bearer",
"id_token": "id_token", "id_token": "id_token",
@ -389,14 +400,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url, client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ui_auth_session_id=None,
) )
request = self._build_callback_request( request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address code, state, session, user_agent=user_agent, ip_address=ip_address
) )
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, {}, expected_user_id, request, client_redirect_url, None,
) )
self.handler._exchange_code.assert_called_once_with(code) self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
@ -427,7 +438,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, {}, expected_user_id, request, client_redirect_url, None,
) )
self.handler._exchange_code.assert_called_once_with(code) self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called() self.handler._parse_id_token.assert_not_called()
@ -597,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url, client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ui_auth_session_id=None,
) )
request = self._build_callback_request("code", state, session) request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
@ -614,9 +625,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test_user", "sub": "test_user",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", ANY, ANY, {} "@test_user:test", ANY, ANY, None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -625,9 +636,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": 1234, "sub": 1234,
"username": "test_user_2", "username": "test_user_2",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", ANY, ANY, {} "@test_user_2:test", ANY, ANY, None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -638,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None) store.register_user(user_id=user3.to_string(), password_hash=None)
) )
userinfo = {"sub": "test3", "username": "test_user_3"} userinfo = {"sub": "test3", "username": "test_user_3"}
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError( self.assertRenderedError(
"mapping_error", "mapping_error",
@ -662,16 +673,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test", "sub": "test",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, {}, user.to_string(), ANY, ANY, None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, {}, user.to_string(), ANY, ANY, None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -684,9 +695,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1", "sub": "test1",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, {}, user.to_string(), ANY, ANY, None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -705,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2", "sub": "test2",
"username": "TEST_USER_2", "username": "TEST_USER_2",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error") args = self.assertRenderedError("mapping_error")
self.assertTrue( self.assertTrue(
@ -720,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None) store.register_user(user_id=user2.to_string(), password_hash=None)
) )
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", ANY, ANY, {}, "@TEST_USER_2:test", ANY, ANY, None,
) )
def test_map_userinfo_to_invalid_localpart(self): def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
self._make_callback_with_userinfo({"sub": "test2", "username": "föö"}) self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
)
self.assertRenderedError("mapping_error", "localpart is invalid: föö") self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config( @override_config(
@ -752,11 +765,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test", "sub": "test",
"username": "test_user", "username": "test_user",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
# test_user is already taken, so test_user1 gets registered instead. # test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", ANY, ANY, {}, "@test_user1:test", ANY, ANY, None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -774,70 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester", "sub": "tester",
"username": "tester", "username": "tester",
} }
self._make_callback_with_userinfo(userinfo) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError( self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response" "mapping_error", "Unable to generate a Matrix ID from the SSO response"
) )
def _make_callback_with_userinfo(
self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None:
self.handler._exchange_code = simple_async_mock(return_value={})
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
state = "state" async def _make_callback_with_userinfo(
session = self.handler._generate_oidc_session_token( hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
state=state, ) -> None:
nonce="nonce", """Mock up an OIDC callback with the given userinfo dict
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = self._build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request)) We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
def _build_callback_request( Args:
self, hs: the HomeServer impl to send the callback to.
code: str, userinfo: the OIDC userinfo dict
state: str, client_redirect_url: the URL to redirect to on success.
session: str, """
user_agent: str = "Browser", handler = hs.get_oidc_handler()
ip_address: str = "10.0.0.1", handler._exchange_code = simple_async_mock(return_value={})
): handler._parse_id_token = simple_async_mock(return_value=userinfo)
"""Builds a fake SynapseRequest to mock the browser callback handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
Returns a Mock object which looks like the SynapseRequest we get from a browser state = "state"
after SSO (before we return to the client) session = handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = _build_callback_request("code", state, session)
Args: await handler.handle_oidc_callback(request)
code: the authorization code which would have been returned by the OIDC
provider
state: the "state" param which would have been passed around in the
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
request.getCookie.return_value = session
request.args = {} def _build_callback_request(
request.args[b"code"] = [code.encode("utf-8")] code: str,
request.args[b"state"] = [state.encode("utf-8")] state: str,
request.getClientIP.return_value = ip_address session: str,
request.get_user_agent.return_value = user_agent user_agent: str = "Browser",
return request ip_address: str = "10.0.0.1",
):
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser
after SSO (before we return to the client)
Args:
code: the authorization code which would have been returned by the OIDC
provider
state: the "state" param which would have been passed around in the
query param. Should be the same as was embedded in the session in
_build_oidc_session.
session: the "session" which would have been passed around in the cookie.
user_agent: the user-agent to present
ip_address: the IP address to pretend the request came from
"""
request = Mock(
spec=[
"args",
"getCookie",
"addCookie",
"requestHeaders",
"getClientIP",
"get_user_agent",
]
)
request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.get_user_agent.return_value = user_agent
return request

View file

@ -443,7 +443,7 @@ class RestHelper:
return channel.json_body return channel.json_body
# an 'oidc_config' suitable for login_with_oidc. # an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_CONFIG = { TEST_OIDC_CONFIG = {
"enabled": True, "enabled": True,
"discover": False, "discover": False,