0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-08 13:48:57 +02:00

Fix bugs in handling clientRedirectUrl, and improve OIDC tests (#9127, #9128)

* Factor out a common TestHtmlParser

Looks like I'm doing this in a few different places.

* Improve OIDC login test

Complete the OIDC login flow, rather than giving up halfway through.

* Ensure that OIDC login works with multiple OIDC providers

* Fix bugs in handling clientRedirectUrl

 - don't drop duplicate query-params, or params with no value
 - allow utf-8 in query-params
This commit is contained in:
Richard van der Hoff 2021-01-18 14:52:49 +00:00 committed by GitHub
parent a8703819eb
commit 02070c69fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 189 additions and 86 deletions

1
changelog.d/9127.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

1
changelog.d/9128.bugfix Normal file
View file

@ -0,0 +1 @@
Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login.

View file

@ -1504,8 +1504,8 @@ class AuthHandler(BaseHandler):
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({param_name: param})
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)

View file

@ -85,7 +85,7 @@ class OidcHandler:
self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
}
} # type: Dict[str, OidcProvider]
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.

View file

@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource):
self._server_name = hs.hostname
async def _async_render_GET(self, request: SynapseRequest) -> None:
client_redirect_url = parse_string(request, "redirectUrl", required=True)
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding="utf-8"
)
idp = parse_string(request, "idp", required=False)
# if we need to pick an IdP, do so

View file

@ -15,9 +15,8 @@
import time
import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse
from typing import Any, Dict, Union
from urllib.parse import urlencode
from mock import Mock
@ -38,6 +37,7 @@ from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
@ -69,6 +69,12 @@ TEST_SAML_METADATA = """
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="%26=o"'
# the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
},
}
# default OIDC provider
config["oidc_config"] = TEST_OIDC_CONFIG
# additional OIDC providers
config["oidc_providers"] = [
{
"idp_id": "idp1",
"idp_name": "IDP1",
"discover": False,
"issuer": "https://issuer1",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://issuer1/auth",
"token_endpoint": "https://issuer1/token",
"userinfo_endpoint": "https://issuer1/userinfo",
"user_mapping_provider": {
"config": {"localpart_template": "{{ user.sub }}"}
},
}
]
return config
def create_resource_dict(self) -> Dict[str, Resource]:
from synapse.rest.oidc import OIDCResource
d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
client_redirect_url = "https://x?<abc>"
# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
"/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]
@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
class FormPageParser(HTMLParser):
def __init__(self):
super().__init__()
# the values of the hidden inputs: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]
# the values of the radio buttons
self.radios = [] # type: List[Optional[str]]
def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "input":
if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
self.radios.append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
input_name = attr_dict["name"]
assert input_name
self.hiddens[input_name] = attr_dict["value"]
def error(_, message):
self.fail(message)
p = FormPageParser()
p = TestHtmlParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()
self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])
self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
"/_synapse/client/pick_idp?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
service_uri = cas_uri_params["service"][0]
_, service_uri_query = service_uri.split("?", 1)
service_uri_params = urllib.parse.parse_qs(service_uri_query)
self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_saml(self):
"""If SAML is chosen, should redirect to the SAML server"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# the RelayState is used to carry the client redirect url
saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, client_redirect_url)
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_oidc(self):
def test_login_via_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
client_redirect_url = "https://x?<abc>"
# pick the default OIDC provider
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
self.assertEqual(
self._get_value_from_macaroon(macaroon, "client_redirect_url"),
client_redirect_url,
TEST_CLIENT_REDIRECT_URL,
)
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(
channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
)
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
# ... which should contain our redirect link
self.assertEqual(len(p.links), 1)
path, query = p.links[0].split("?", 1)
self.assertEqual(path, "https://x")
# it will have url-encoded the params properly, so we'll have to parse them
params = urllib.parse.parse_qsl(
query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
self.assertEqual(params[2][0], "loginToken")
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
login_token = params[2][1]
chan = self.make_request(
"POST", "/login", content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# whitelist this client URI so we redirect straight to it rather than
# serving a confirmation page
config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
config["sso"] = {"client_whitelist": ["https://x"]}
return config
def create_resource_dict(self) -> Dict[str, Resource]:
@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase):
def test_username_picker(self):
"""Test the happy path of a username picker flow."""
client_redirect_url = "https://whitelisted.client"
# do the start of the login flow
channel = self.helper.auth_via_oidc(
{"sub": "tester", "displayname": "Jonny"}, client_redirect_url
{"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
)
# that should redirect to the username picker
@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
self.assertEqual(session.display_name, "Jonny")
self.assertEqual(session.client_redirect_url, client_redirect_url)
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
# the expiry time should be about 15 minutes away
expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
# ensure that the returned location starts with the requested redirect URL
self.assertEqual(
location_headers[0][: len(client_redirect_url)], client_redirect_url
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x")
# it will have url-encoded the params properly, so we'll have to parse them
params = urllib.parse.parse_qsl(
query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
self.assertEqual(params[2][0], "loginToken")
# fish the login token out of the returned redirect uri
parts = urlparse(location_headers[0])
query = parse_qs(parts.query)
login_token = query["loginToken"][0]
login_token = params[2][1]
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.

View file

@ -20,8 +20,7 @@ import json
import re
import time
import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
from typing import Any, Dict, Mapping, MutableMapping, Optional
from mock import patch
@ -35,6 +34,7 @@ from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
@attr.s
@ -440,10 +440,36 @@ class RestHelper:
# param that synapse passes to the IdP via query params, as well as the cookie
# that synapse passes to the client.
oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
oauth_uri_path, _ = oauth_uri.split("?", 1)
assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
"unexpected SSO URI " + oauth_uri_path
)
return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
def complete_oidc_auth(
self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
) -> FakeChannel:
"""Mock out an OIDC authentication flow
Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
sent back to the OIDC provider.
Requires the OIDC callback resource to be mounted at the normal place.
Args:
oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
from initiate_sso_login or initiate_sso_ui_auth).
cookies: the cookies set by synapse's redirect endpoint, which will be
sent back to the callback endpoint.
user_info_dict: the remote userinfo that the OIDC provider should present.
Typically this should be '{"sub": "<remote user id>"}'.
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
"""
_, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
@ -456,9 +482,9 @@ class RestHelper:
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
("https://issuer.test/token", {"access_token": "TEST"}),
(TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
("https://issuer.test/userinfo", user_info_dict),
(TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
]
async def mock_req(method: str, uri: str, data=None, headers=None):
@ -542,25 +568,7 @@ class RestHelper:
channel.extract_cookies(cookies)
# parse the confirmation page to fish out the link.
class ConfirmationPageParser(HTMLParser):
def __init__(self):
super().__init__()
self.links = [] # type: List[str]
def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "a":
href = attr_dict["href"]
if href:
self.links.append(href)
def error(_, message):
raise AssertionError(message)
p = ConfirmationPageParser()
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
assert len(p.links) == 1, "not exactly one link in confirmation page"
@ -570,6 +578,8 @@ class RestHelper:
# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = {
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
"token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo",
"token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
"userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}

View file

@ -74,7 +74,7 @@ class FakeChannel:
return int(self.result["code"])
@property
def headers(self):
def headers(self) -> Headers:
if not self.result:
raise Exception("No result yet.")
h = Headers()

View file

@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from html.parser import HTMLParser
from typing import Dict, Iterable, List, Optional, Tuple
class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""
def __init__(self):
super().__init__()
# a list of links found in the doc
self.links = [] # type: List[str]
# the values of any hidden <input>s: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]
# the values of any radio buttons: map from name to list of values
self.radios = {} # type: Dict[str, List[Optional[str]]]
def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "a":
href = attr_dict["href"]
if href:
self.links.append(href)
elif tag == "input":
input_name = attr_dict.get("name")
if attr_dict["type"] == "radio":
assert input_name
self.radios.setdefault(input_name, []).append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
assert input_name
self.hiddens[input_name] = attr_dict["value"]
def error(_, message):
raise AssertionError(message)