Merge branch 'social_login' into develop

This commit is contained in:
Richard van der Hoff 2021-01-27 17:27:24 +00:00
commit 7fa1346f93
9 changed files with 230 additions and 21 deletions

1
changelog.d/9183.feature Normal file
View file

@ -0,0 +1 @@
Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858).

View file

@ -9,6 +9,7 @@ from synapse.config import (
consent_config, consent_config,
database, database,
emailconfig, emailconfig,
experimental,
groups, groups,
jwt_config, jwt_config,
key, key,
@ -49,6 +50,7 @@ def path_exists(file_path: str): ...
class RootConfig: class RootConfig:
server: server.ServerConfig server: server.ServerConfig
experimental: experimental.ExperimentalConfig
tls: tls.TlsConfig tls: tls.TlsConfig
database: database.DatabaseConfig database: database.DatabaseConfig
logging: logger.LoggingConfig logging: logger.LoggingConfig

View file

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config._base import Config
from synapse.types import JsonDict
class ExperimentalConfig(Config):
"""Config section for enabling experimental features"""
section = "experimental"
def read_config(self, config: JsonDict, **kwargs):
experimental = config.get("experimental_features") or {}
# MSC2858 (multiple SSO identity providers)
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool

View file

@ -24,6 +24,7 @@ from .cas import CasConfig
from .consent_config import ConsentConfig from .consent_config import ConsentConfig
from .database import DatabaseConfig from .database import DatabaseConfig
from .emailconfig import EmailConfig from .emailconfig import EmailConfig
from .experimental import ExperimentalConfig
from .federation import FederationConfig from .federation import FederationConfig
from .groups import GroupsConfig from .groups import GroupsConfig
from .jwt_config import JWTConfig from .jwt_config import JWTConfig
@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
config_classes = [ config_classes = [
ServerConfig, ServerConfig,
ExperimentalConfig,
TlsConfig, TlsConfig,
FederationConfig, FederationConfig,
CacheConfig, CacheConfig,

View file

@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request from twisted.web.http import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
@ -235,7 +235,10 @@ class SsoHandler:
respond_with_html(request, code, html) respond_with_html(request, code, html)
async def handle_redirect_request( async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes, self,
request: SynapseRequest,
client_redirect_url: bytes,
idp_id: Optional[str],
) -> str: ) -> str:
"""Handle a request to /login/sso/redirect """Handle a request to /login/sso/redirect
@ -243,6 +246,7 @@ class SsoHandler:
request: incoming HTTP request request: incoming HTTP request
client_redirect_url: the URL that we should redirect the client_redirect_url: the URL that we should redirect the
client to after login. client to after login.
idp_id: optional identity provider chosen by the client
Returns: Returns:
the URI to redirect to the URI to redirect to
@ -252,10 +256,19 @@ class SsoHandler:
400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
) )
# if the client chose an IdP, use that
idp = None # type: Optional[SsoIdentityProvider]
if idp_id:
idp = self._identity_providers.get(idp_id)
if not idp:
raise NotFoundError("Unknown identity provider")
# if we only have one auth provider, redirect to it directly # if we only have one auth provider, redirect to it directly
if len(self._identity_providers) == 1: elif len(self._identity_providers) == 1:
ap = next(iter(self._identity_providers.values())) idp = next(iter(self._identity_providers.values()))
return await ap.handle_redirect_request(request, client_redirect_url)
if idp:
return await idp.handle_redirect_request(request, client_redirect_url)
# otherwise, redirect to the IDP picker # otherwise, redirect to the IDP picker
return "/_synapse/client/pick_idp?" + urlencode( return "/_synapse/client/pick_idp?" + urlencode(

View file

@ -22,10 +22,22 @@ import types
import urllib import urllib
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
Iterator,
List,
Pattern,
Tuple,
Union,
)
import jinja2 import jinja2
from canonicaljson import iterencode_canonical_json from canonicaljson import iterencode_canonical_json
from typing_extensions import Protocol
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer, interfaces from twisted.internet import defer, interfaces
@ -168,11 +180,25 @@ def wrap_async_request_handler(h):
return preserve_fn(wrapped_async_request_handler) return preserve_fn(wrapped_async_request_handler)
class HttpServer: # Type of a callback method for processing requests
# it is actually called with a SynapseRequest and a kwargs dict for the params,
# but I can't figure out how to represent that.
ServletCallback = Callable[
..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
]
class HttpServer(Protocol):
""" Interface for registering callbacks on a HTTP server """ Interface for registering callbacks on a HTTP server
""" """
def register_paths(self, method, path_patterns, callback): def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
""" Register a callback that gets fired if we receive a http request """ Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex. with the given method for a path that matches the given regex.
@ -180,12 +206,14 @@ class HttpServer:
an unpacked tuple. an unpacked tuple.
Args: Args:
method (str): The method to listen to. method: The HTTP method to listen to.
path_patterns (list<SRE_Pattern>): The regex used to match requests. path_patterns: The regex used to match requests.
callback (function): The function to fire if we receive a matched callback: The function to fire if we receive a matched
request. The first argument will be the request object and request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex. subsequent arguments will be any matched groups from the regex.
This should return a tuple of (code, response). This should return either tuple of (code, response), or None.
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
""" """
pass pass
@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource):
def _get_handler_for_request( def _get_handler_for_request(
self, request: SynapseRequest self, request: SynapseRequest
) -> Tuple[Callable, str, Dict[str, str]]: ) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request. """Finds a callback method to handle the given request.
Returns: Returns:

View file

@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.http.server import finish_request from synapse.handlers.sso import SsoIdentityProvider
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
parse_json_object_from_request, parse_json_object_from_request,
@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled self.oidc_enabled = hs.config.oidc_enabled
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter( self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(), clock=hs.get_clock(),
@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE}) flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE}) sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
# While its valid for us to advertise this login type generally,
if self._msc2858_enabled:
sso_flow["org.matrix.msc2858.identity_providers"] = [
_get_auth_flow_dict_for_idp(idp)
for idp in self._sso_handler.get_identity_providers().values()
]
flows.append(sso_flow)
# While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the # synapse currently only gives out these tokens as part of the
# SSO login flow. # SSO login flow.
# Generally we don't want to advertise login flows that clients # Generally we don't want to advertise login flows that clients
@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet):
return result return result
def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
"""Return an entry for the login flow dict
Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon:
e["icon"] = idp.idp_icon
return e
class SsoRedirectServlet(RestServlet): class SsoRedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they # make sure that the relevant handlers are instantiated, so that they
@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet):
if hs.config.oidc_enabled: if hs.config.oidc_enabled:
hs.get_oidc_handler() hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
async def on_GET(self, request: SynapseRequest): def register(self, http_server: HttpServer) -> None:
super().register(http_server)
if self._msc2858_enabled:
# expose additional endpoint for MSC2858 support
http_server.register_paths(
"GET",
client_patterns(
"/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
releases=(),
unstable=True,
),
self.on_GET,
self.__class__.__name__,
)
async def on_GET(
self, request: SynapseRequest, idp_id: Optional[str] = None
) -> None:
client_redirect_url = parse_string( client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None request, "redirectUrl", required=True, encoding=None
) )
sso_url = await self._sso_handler.handle_redirect_request( sso_url = await self._sso_handler.handle_redirect_request(
request, client_redirect_url request, client_redirect_url, idp_id,
) )
logger.info("Redirecting to %s", sso_url) logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url) request.redirect(sso_url)

View file

@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
# the query params in TEST_CLIENT_REDIRECT_URL # the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
class LoginRestServletTestCase(unittest.HomeserverTestCase): class LoginRestServletTestCase(unittest.HomeserverTestCase):
@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
d["/_synapse/oidc"] = OIDCResource(self.hs) d["/_synapse/oidc"] = OIDCResource(self.hs)
return d return d
def test_get_login_flows(self):
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
expected_flows = [
{"type": "m.login.cas"},
{"type": "m.login.sso"},
{"type": "m.login.token"},
{"type": "m.login.password"},
] + ADDITIONAL_LOGIN_FLOWS
self.assertCountEqual(channel.json_body["flows"], expected_flows)
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_get_msc2858_login_flows(self):
"""The SSO flow should include IdP info if MSC2858 is enabled"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
# stick the flows results in a dict by type
flow_results = {} # type: Dict[str, Any]
for f in channel.json_body["flows"]:
flow_type = f["type"]
self.assertNotIn(
flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
)
flow_results[flow_type] = f
self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
sso_flow = flow_results.pop("m.login.sso")
# we should have a set of IdPs
self.assertCountEqual(
sso_flow["org.matrix.msc2858.identity_providers"],
[
{"id": "cas", "name": "CAS"},
{"id": "saml", "name": "SAML"},
{"id": "oidc-idp1", "name": "IDP1"},
{"id": "oidc", "name": "OIDC"},
],
)
# the rest of the flows are simple
expected_flows = [
{"type": "m.login.cas"},
{"type": "m.login.token"},
{"type": "m.login.password"},
] + ADDITIONAL_LOGIN_FLOWS
self.assertCountEqual(flow_results.values(), expected_flows)
def test_multi_sso_redirect(self): def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker""" """/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker # first hit the redirect url, which should redirect to our idp picker
@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_msc2858_disabled(self):
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_unknown(self):
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
@override_config({"experimental_features": {"msc2858_enabled": True}})
def test_client_idp_redirect_to_oidc(self):
"""If the client pick a known IdP, redirect to it"""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
@staticmethod @staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = " prefix = key + " = "

View file

@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None):
# This is a mock /resource/ not an entire server # This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer): class MockHttpResource:
def __init__(self, prefix=""): def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix self.prefix = prefix