Preparatory refactors of OidcHandler (#9067)

Some light refactoring of OidcHandler, in preparation for bigger things:

  * remove inheritance from deprecated BaseHandler
  * add an object to hold the things that go into a session cookie
  * factor out a separate class for manipulating said cookies
This commit is contained in:
Richard van der Hoff 2021-01-13 10:26:12 +00:00 committed by GitHub
parent 7a2e9b549d
commit bc4bf7b384
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 201 additions and 165 deletions

1
changelog.d/9067.feature Normal file
View file

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import logging import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from urllib.parse import urlencode from urllib.parse import urlencode
import attr import attr
@ -35,7 +35,6 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody from twisted.web.client import readBody
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -85,12 +84,15 @@ class OidcError(Exception):
return self.error return self.error
class OidcHandler(BaseHandler): class OidcHandler:
"""Handles requests related to the OpenID Connect login flow. """Handles requests related to the OpenID Connect login flow.
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self._store = hs.get_datastore()
self._token_generator = OidcSessionTokenGenerator(hs)
self._callback_url = hs.config.oidc_callback_url # type: str self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str] self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@ -116,7 +118,6 @@ class OidcHandler(BaseHandler):
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table # identifier for the external_ids table
self.idp_id = "oidc" self.idp_id = "oidc"
@ -519,11 +520,13 @@ class OidcHandler(BaseHandler):
if not client_redirect_url: if not client_redirect_url:
client_redirect_url = b"" client_redirect_url = b""
cookie = self._generate_oidc_session_token( cookie = self._token_generator.generate_oidc_session_token(
state=state, state=state,
nonce=nonce, session_data=OidcSessionData(
client_redirect_url=client_redirect_url.decode(), nonce=nonce,
ui_auth_session_id=ui_auth_session_id, client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
),
) )
request.addCookie( request.addCookie(
SESSION_COOKIE_NAME, SESSION_COOKIE_NAME,
@ -628,11 +631,9 @@ class OidcHandler(BaseHandler):
# Deserialize the session token and verify it. # Deserialize the session token and verify it.
try: try:
( session_data = self._token_generator.verify_oidc_session_token(
nonce, session, state
client_redirect_url, )
ui_auth_session_id,
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e: except MacaroonDeserializationException as e:
logger.exception("Invalid session") logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e)) self._sso_handler.render_error(request, "invalid_session", str(e))
@ -674,14 +675,14 @@ class OidcHandler(BaseHandler):
else: else:
logger.debug("Extracting userinfo from id_token") logger.debug("Extracting userinfo from id_token")
try: try:
userinfo = await self._parse_id_token(token, nonce=nonce) userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e: except Exception as e:
logger.exception("Invalid id_token") logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e)) self._sso_handler.render_error(request, "invalid_token", str(e))
return return
# first check if we're doing a UIA # first check if we're doing a UIA
if ui_auth_session_id: if session_data.ui_auth_session_id:
try: try:
remote_user_id = self._remote_id_from_userinfo(userinfo) remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e: except Exception as e:
@ -690,7 +691,7 @@ class OidcHandler(BaseHandler):
return return
return await self._sso_handler.complete_sso_ui_auth_request( return await self._sso_handler.complete_sso_ui_auth_request(
self.idp_id, remote_user_id, ui_auth_session_id, request self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
) )
# otherwise, it's a login # otherwise, it's a login
@ -698,133 +699,12 @@ class OidcHandler(BaseHandler):
# Call the mapper to register/login the user # Call the mapper to register/login the user
try: try:
await self._complete_oidc_login( await self._complete_oidc_login(
userinfo, token, request, client_redirect_url userinfo, token, request, session_data.client_redirect_url
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
if ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def _verify_oidc_session_token(
self, session: bytes, state: str
) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self.clock.time_msec()
return now < expiry
async def _complete_oidc_login( async def _complete_oidc_login(
self, self,
userinfo: UserInfo, userinfo: UserInfo,
@ -901,8 +781,8 @@ class OidcHandler(BaseHandler):
# and attempt to match it. # and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0) attributes = await oidc_response_to_user_attributes(failures=0)
user_id = UserID(attributes.localpart, self.server_name).to_string() user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id) users = await self._store.get_users_by_id_case_insensitive(user_id)
if users: if users:
# If an existing matrix ID is returned, then use it. # If an existing matrix ID is returned, then use it.
if len(users) == 1: if len(users) == 1:
@ -954,6 +834,148 @@ class OidcHandler(BaseHandler):
return str(remote_user_id) return str(remote_user_id)
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._server_name = hs.hostname
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
def generate_oidc_session_token(
self,
state: str,
session_data: "OidcSessionData",
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
session_data: data to include in the session token.
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def verify_oidc_session_token(
self, session: bytes, state: str
) -> "OidcSessionData":
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The data extracted from the session cookie
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
)
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
@attr.s(frozen=True, slots=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
# The session ID of the ongoing UI Auth (None if this is a login)
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
UserAttributeDict = TypedDict( UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
) )

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import json import json
import re import re
from typing import Dict from typing import Dict, Optional
from urllib.parse import parse_qs, urlencode, urlparse from urllib.parse import parse_qs, urlencode, urlparse
from mock import ANY, Mock, patch from mock import ANY, Mock, patch
@ -349,9 +349,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
cookie = args[1] cookie = args[1]
macaroon = pymacaroons.Macaroon.deserialize(cookie) macaroon = pymacaroons.Macaroon.deserialize(cookie)
state = self.handler._get_value_from_macaroon(macaroon, "state") state = self.handler._token_generator._get_value_from_macaroon(
nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") macaroon, "state"
redirect = self.handler._get_value_from_macaroon( )
nonce = self.handler._token_generator._get_value_from_macaroon(
macaroon, "nonce"
)
redirect = self.handler._token_generator._get_value_from_macaroon(
macaroon, "client_redirect_url" macaroon, "client_redirect_url"
) )
@ -411,12 +415,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect" client_redirect_url = "http://client/redirect"
user_agent = "Browser" user_agent = "Browser"
ip_address = "10.0.0.1" ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token( session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request = _build_callback_request( request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address code, state, session, user_agent=user_agent, ip_address=ip_address
) )
@ -500,11 +499,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_session") self.assertRenderedError("invalid_session")
# Mismatching session # Mismatching session
session = self.handler._generate_oidc_session_token( session = self._generate_oidc_session_token(
state="state", state="state", nonce="nonce", client_redirect_url="http://client/redirect",
nonce="nonce",
client_redirect_url="http://client/redirect",
ui_auth_session_id=None,
) )
request.args = {} request.args = {}
request.args[b"state"] = [b"mismatching state"] request.args[b"state"] = [b"mismatching state"]
@ -623,11 +619,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
state = "state" state = "state"
client_redirect_url = "http://client/redirect" client_redirect_url = "http://client/redirect"
session = self.handler._generate_oidc_session_token( session = self._generate_oidc_session_token(
state=state, state=state, nonce="nonce", client_redirect_url=client_redirect_url,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
) )
request = _build_callback_request("code", state, session) request = _build_callback_request("code", state, session)
@ -841,6 +834,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
self.assertRenderedError("mapping_error", "localpart is invalid: ") self.assertRenderedError("mapping_error", "localpart is invalid: ")
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str] = None,
) -> str:
from synapse.handlers.oidc_handler import OidcSessionData
return self.handler._token_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
),
)
class UsernamePickerTestCase(HomeserverTestCase): class UsernamePickerTestCase(HomeserverTestCase):
if not HAS_OIDC: if not HAS_OIDC:
@ -965,17 +976,19 @@ async def _make_callback_with_userinfo(
userinfo: the OIDC userinfo dict userinfo: the OIDC userinfo dict
client_redirect_url: the URL to redirect to on success. client_redirect_url: the URL to redirect to on success.
""" """
from synapse.handlers.oidc_handler import OidcSessionData
handler = hs.get_oidc_handler() handler = hs.get_oidc_handler()
handler._exchange_code = simple_async_mock(return_value={}) handler._exchange_code = simple_async_mock(return_value={})
handler._parse_id_token = simple_async_mock(return_value=userinfo) handler._parse_id_token = simple_async_mock(return_value=userinfo)
handler._fetch_userinfo = simple_async_mock(return_value=userinfo) handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state" state = "state"
session = handler._generate_oidc_session_token( session = handler._token_generator.generate_oidc_session_token(
state=state, state=state,
nonce="nonce", session_data=OidcSessionData(
client_redirect_url=client_redirect_url, nonce="nonce", client_redirect_url=client_redirect_url,
ui_auth_session_id=None, ),
) )
request = _build_callback_request("code", state, session) request = _build_callback_request("code", state, session)