0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-21 02:51:54 +01:00

Abstract shared SSO code. (#8765)

De-duplicates code between the SAML and OIDC implementations.
This commit is contained in:
Patrick Cloke 2020-11-17 09:46:23 -05:00 committed by GitHub
parent e487d9fabc
commit ee382025b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 159 additions and 120 deletions

1
changelog.d/8765.misc Normal file
View file

@ -0,0 +1 @@
Consolidate logic between the OpenID Connect and SAML code.

View file

@ -34,7 +34,8 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@ -83,17 +84,12 @@ class OidcError(Exception):
return self.error
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
class OidcHandler:
class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@ -120,36 +116,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
self._sso_handler = hs.get_sso_handler()
def _validate_metadata(self):
"""Verifies the provider metadata.
@ -571,7 +544,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
``self._render_error`` which displays an HTML page for the error.
``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
@ -609,7 +582,7 @@ class OidcHandler:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
self._render_error(request, error, description)
self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
@ -619,7 +592,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return
# Remove the cookie. There is a good chance that if the callback failed
@ -637,7 +612,9 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
self._render_error(request, "invalid_request", "State parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return
state = request.args[b"state"][0].decode()
@ -651,17 +628,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
self._render_error(request, "mismatching_session", str(e))
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
self._render_error(request, "invalid_request", "Code parameter is missing")
self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return
logger.debug("Exchanging code")
@ -670,7 +649,7 @@ class OidcHandler:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
self._render_error(request, e.error, e.error_description)
self._sso_handler.render_error(request, e.error, e.error_description)
return
logger.debug("Successfully obtained OAuth2 access token")
@ -683,7 +662,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
self._render_error(request, "fetch_error", str(e))
self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
@ -691,7 +670,7 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._render_error(request, "invalid_token", str(e))
self._sso_handler.render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
@ -705,7 +684,7 @@ class OidcHandler:
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
@ -770,7 +749,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self._clock.time_msec()
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@ -845,7 +824,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@ -885,20 +864,14 @@ class OidcHandler:
# to be strings.
remote_user_id = str(remote_user_id)
logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
)
if previously_registered_user_id:
return previously_registered_user_id
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
# Otherwise, generate a new user.
try:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
@ -917,8 +890,8 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
user_id = UserID(localpart, self._hostname).to_string()
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
@ -942,7 +915,8 @@ class OidcHandler:
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
return registered_user_id

View file

@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.server import respond_with_html
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@ -42,10 +43,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the SAML2 response to a user."""
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@ -57,17 +54,13 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
class SamlHandler:
class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@ -88,26 +81,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@ -130,7 +106,7 @@ class SamlHandler:
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
now = self._clock.time_msec()
now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@ -171,12 +147,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
self._render_error(
self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
self._render_error(
self._sso_handler.render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
@ -184,7 +160,7 @@ class SamlHandler:
return
if saml2_auth.not_signed:
self._render_error(
self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
@ -210,7 +186,7 @@ class SamlHandler:
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error(
self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
@ -226,7 +202,7 @@ class SamlHandler:
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
@ -274,17 +250,11 @@ class SamlHandler:
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
"Looking for existing mapping for user %s:%s",
self._auth_provider_id,
remote_user_id,
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
self._auth_provider_id, remote_user_id
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
if previously_registered_user_id:
return previously_registered_user_id
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
@ -294,7 +264,7 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
map_username_to_mxid_localpart(attrval), self._hostname
map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
logger.info(
"Looking for existing account based on mapped %s %s",
@ -302,11 +272,11 @@ class SamlHandler:
user_id,
)
users = await self._datastore.get_users_by_id_case_insensitive(user_id)
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
@ -335,8 +305,8 @@ class SamlHandler:
emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
if not await self.store.get_users_by_id_case_insensitive(
UserID(localpart, self.server_name).to_string()
):
# This mxid is free
break
@ -348,7 +318,6 @@ class SamlHandler:
)
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=displayname,
@ -356,13 +325,13 @@ class SamlHandler:
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before:

90
synapse/handlers/sso.py Normal file
View file

@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
class SsoHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._error_template = hs.config.sso_error_template
def render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and responds with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
) -> Optional[str]:
"""
Maps the user ID of a remote IdP to a mxid for a previously seen user.
If the user has not been seen yet, this will return None.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
remote_user_id: The user ID according to the remote IdP. This might
be an e-mail address, a GUID, or some other form. It must be
unique and immutable.
Returns:
The mxid of a previously seen user.
"""
# Check if we already have a mapping for this user.
logger.info(
"Looking for existing mapping for user %s:%s",
auth_provider_id,
remote_user_id,
)
previously_registered_user_id = await self.store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
# A match was found, return the user ID.
if previously_registered_user_id is not None:
logger.info("Found existing mapping %s", previously_registered_user_id)
return previously_registered_user_id
# No match.
return None

View file

@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
@ -390,6 +391,10 @@ class HomeServer(metaclass=abc.ABCMeta):
else:
return FollowerTypingHandler(self)
@cache_in_self
def get_sso_handler(self) -> SsoHandler:
return SsoHandler(self)
@cache_in_self
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)

View file

@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler = OidcHandler(hs)
# Mock the render error method.
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
return hs
@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
args = self.handler._render_error.call_args[0]
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.handler._render_error.reset_mock()
self.render_error.reset_mock()
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request))
@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
"preferred_username": "bar",
}
user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock(
@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
self.render_error.assert_not_called()
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie