mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 08:13:50 +01:00
Preparatory refactors of OidcHandler (#9067)
Some light refactoring of OidcHandler, in preparation for bigger things: * remove inheritance from deprecated BaseHandler * add an object to hold the things that go into a session cookie * factor out a separate class for manipulating said cookies
This commit is contained in:
parent
7a2e9b549d
commit
bc4bf7b384
3 changed files with 201 additions and 165 deletions
1
changelog.d/9067.feature
Normal file
1
changelog.d/9067.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add support for multiple SSO Identity Providers.
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import attr
|
||||
|
@ -35,7 +35,6 @@ from typing_extensions import TypedDict
|
|||
from twisted.web.client import readBody
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.handlers.sso import MappingException, UserAttributes
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
|
@ -85,12 +84,15 @@ class OidcError(Exception):
|
|||
return self.error
|
||||
|
||||
|
||||
class OidcHandler(BaseHandler):
|
||||
class OidcHandler:
|
||||
"""Handles requests related to the OpenID Connect login flow.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self._store = hs.get_datastore()
|
||||
|
||||
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
|
||||
|
@ -116,7 +118,6 @@ class OidcHandler(BaseHandler):
|
|||
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._server_name = hs.config.server_name # type: str
|
||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
|
||||
# identifier for the external_ids table
|
||||
self.idp_id = "oidc"
|
||||
|
@ -519,11 +520,13 @@ class OidcHandler(BaseHandler):
|
|||
if not client_redirect_url:
|
||||
client_redirect_url = b""
|
||||
|
||||
cookie = self._generate_oidc_session_token(
|
||||
cookie = self._token_generator.generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url.decode(),
|
||||
ui_auth_session_id=ui_auth_session_id,
|
||||
session_data=OidcSessionData(
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url.decode(),
|
||||
ui_auth_session_id=ui_auth_session_id,
|
||||
),
|
||||
)
|
||||
request.addCookie(
|
||||
SESSION_COOKIE_NAME,
|
||||
|
@ -628,11 +631,9 @@ class OidcHandler(BaseHandler):
|
|||
|
||||
# Deserialize the session token and verify it.
|
||||
try:
|
||||
(
|
||||
nonce,
|
||||
client_redirect_url,
|
||||
ui_auth_session_id,
|
||||
) = self._verify_oidc_session_token(session, state)
|
||||
session_data = self._token_generator.verify_oidc_session_token(
|
||||
session, state
|
||||
)
|
||||
except MacaroonDeserializationException as e:
|
||||
logger.exception("Invalid session")
|
||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||
|
@ -674,14 +675,14 @@ class OidcHandler(BaseHandler):
|
|||
else:
|
||||
logger.debug("Extracting userinfo from id_token")
|
||||
try:
|
||||
userinfo = await self._parse_id_token(token, nonce=nonce)
|
||||
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
|
||||
except Exception as e:
|
||||
logger.exception("Invalid id_token")
|
||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||
return
|
||||
|
||||
# first check if we're doing a UIA
|
||||
if ui_auth_session_id:
|
||||
if session_data.ui_auth_session_id:
|
||||
try:
|
||||
remote_user_id = self._remote_id_from_userinfo(userinfo)
|
||||
except Exception as e:
|
||||
|
@ -690,7 +691,7 @@ class OidcHandler(BaseHandler):
|
|||
return
|
||||
|
||||
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||
self.idp_id, remote_user_id, ui_auth_session_id, request
|
||||
self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
|
||||
)
|
||||
|
||||
# otherwise, it's a login
|
||||
|
@ -698,133 +699,12 @@ class OidcHandler(BaseHandler):
|
|||
# Call the mapper to register/login the user
|
||||
try:
|
||||
await self._complete_oidc_login(
|
||||
userinfo, token, request, client_redirect_url
|
||||
userinfo, token, request, session_data.client_redirect_url
|
||||
)
|
||||
except MappingException as e:
|
||||
logger.exception("Could not map user")
|
||||
self._sso_handler.render_error(request, "mapping_error", str(e))
|
||||
|
||||
def _generate_oidc_session_token(
|
||||
self,
|
||||
state: str,
|
||||
nonce: str,
|
||||
client_redirect_url: str,
|
||||
ui_auth_session_id: Optional[str],
|
||||
duration_in_ms: int = (60 * 60 * 1000),
|
||||
) -> str:
|
||||
"""Generates a signed token storing data about an OIDC session.
|
||||
|
||||
When Synapse initiates an authorization flow, it creates a random state
|
||||
and a random nonce. Those parameters are given to the provider and
|
||||
should be verified when the client comes back from the provider.
|
||||
It is also used to store the client_redirect_url, which is used to
|
||||
complete the SSO login flow.
|
||||
|
||||
Args:
|
||||
state: The ``state`` parameter passed to the OIDC provider.
|
||||
nonce: The ``nonce`` parameter passed to the OIDC provider.
|
||||
client_redirect_url: The URL the client gave when it initiated the
|
||||
flow.
|
||||
ui_auth_session_id: The session ID of the ongoing UI Auth (or
|
||||
None if this is a login).
|
||||
duration_in_ms: An optional duration for the token in milliseconds.
|
||||
Defaults to an hour.
|
||||
|
||||
Returns:
|
||||
A signed macaroon token with the session information.
|
||||
"""
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = session")
|
||||
macaroon.add_first_party_caveat("state = %s" % (state,))
|
||||
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
|
||||
macaroon.add_first_party_caveat(
|
||||
"client_redirect_url = %s" % (client_redirect_url,)
|
||||
)
|
||||
if ui_auth_session_id:
|
||||
macaroon.add_first_party_caveat(
|
||||
"ui_auth_session_id = %s" % (ui_auth_session_id,)
|
||||
)
|
||||
now = self.clock.time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
|
||||
return macaroon.serialize()
|
||||
|
||||
def _verify_oidc_session_token(
|
||||
self, session: bytes, state: str
|
||||
) -> Tuple[str, str, Optional[str]]:
|
||||
"""Verifies and extract an OIDC session token.
|
||||
|
||||
This verifies that a given session token was issued by this homeserver
|
||||
and extract the nonce and client_redirect_url caveats.
|
||||
|
||||
Args:
|
||||
session: The session token to verify
|
||||
state: The state the OIDC provider gave back
|
||||
|
||||
Returns:
|
||||
The nonce, client_redirect_url, and ui_auth_session_id for this session
|
||||
"""
|
||||
macaroon = pymacaroons.Macaroon.deserialize(session)
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_exact("gen = 1")
|
||||
v.satisfy_exact("type = session")
|
||||
v.satisfy_exact("state = %s" % (state,))
|
||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
||||
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
|
||||
# to always satisfy this.
|
||||
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
|
||||
v.verify(macaroon, self._macaroon_secret_key)
|
||||
|
||||
# Extract the `nonce`, `client_redirect_url`, and maybe the
|
||||
# `ui_auth_session_id` from the token.
|
||||
nonce = self._get_value_from_macaroon(macaroon, "nonce")
|
||||
client_redirect_url = self._get_value_from_macaroon(
|
||||
macaroon, "client_redirect_url"
|
||||
)
|
||||
try:
|
||||
ui_auth_session_id = self._get_value_from_macaroon(
|
||||
macaroon, "ui_auth_session_id"
|
||||
) # type: Optional[str]
|
||||
except ValueError:
|
||||
ui_auth_session_id = None
|
||||
|
||||
return nonce, client_redirect_url, ui_auth_session_id
|
||||
|
||||
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
|
||||
"""Extracts a caveat value from a macaroon token.
|
||||
|
||||
Args:
|
||||
macaroon: the token
|
||||
key: the key of the caveat to extract
|
||||
|
||||
Returns:
|
||||
The extracted value
|
||||
|
||||
Raises:
|
||||
Exception: if the caveat was not in the macaroon
|
||||
"""
|
||||
prefix = key + " = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(prefix):
|
||||
return caveat.caveat_id[len(prefix) :]
|
||||
raise ValueError("No %s caveat in macaroon" % (key,))
|
||||
|
||||
def _verify_expiry(self, caveat: str) -> bool:
|
||||
prefix = "time < "
|
||||
if not caveat.startswith(prefix):
|
||||
return False
|
||||
expiry = int(caveat[len(prefix) :])
|
||||
now = self.clock.time_msec()
|
||||
return now < expiry
|
||||
|
||||
async def _complete_oidc_login(
|
||||
self,
|
||||
userinfo: UserInfo,
|
||||
|
@ -901,8 +781,8 @@ class OidcHandler(BaseHandler):
|
|||
# and attempt to match it.
|
||||
attributes = await oidc_response_to_user_attributes(failures=0)
|
||||
|
||||
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
user_id = UserID(attributes.localpart, self._server_name).to_string()
|
||||
users = await self._store.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
# If an existing matrix ID is returned, then use it.
|
||||
if len(users) == 1:
|
||||
|
@ -954,6 +834,148 @@ class OidcHandler(BaseHandler):
|
|||
return str(remote_user_id)
|
||||
|
||||
|
||||
class OidcSessionTokenGenerator:
|
||||
"""Methods for generating and checking OIDC Session cookies."""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._clock = hs.get_clock()
|
||||
self._server_name = hs.hostname
|
||||
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
|
||||
|
||||
def generate_oidc_session_token(
|
||||
self,
|
||||
state: str,
|
||||
session_data: "OidcSessionData",
|
||||
duration_in_ms: int = (60 * 60 * 1000),
|
||||
) -> str:
|
||||
"""Generates a signed token storing data about an OIDC session.
|
||||
|
||||
When Synapse initiates an authorization flow, it creates a random state
|
||||
and a random nonce. Those parameters are given to the provider and
|
||||
should be verified when the client comes back from the provider.
|
||||
It is also used to store the client_redirect_url, which is used to
|
||||
complete the SSO login flow.
|
||||
|
||||
Args:
|
||||
state: The ``state`` parameter passed to the OIDC provider.
|
||||
session_data: data to include in the session token.
|
||||
duration_in_ms: An optional duration for the token in milliseconds.
|
||||
Defaults to an hour.
|
||||
|
||||
Returns:
|
||||
A signed macaroon token with the session information.
|
||||
"""
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
|
||||
)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = session")
|
||||
macaroon.add_first_party_caveat("state = %s" % (state,))
|
||||
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
|
||||
macaroon.add_first_party_caveat(
|
||||
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
||||
)
|
||||
if session_data.ui_auth_session_id:
|
||||
macaroon.add_first_party_caveat(
|
||||
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
|
||||
)
|
||||
now = self._clock.time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
|
||||
return macaroon.serialize()
|
||||
|
||||
def verify_oidc_session_token(
|
||||
self, session: bytes, state: str
|
||||
) -> "OidcSessionData":
|
||||
"""Verifies and extract an OIDC session token.
|
||||
|
||||
This verifies that a given session token was issued by this homeserver
|
||||
and extract the nonce and client_redirect_url caveats.
|
||||
|
||||
Args:
|
||||
session: The session token to verify
|
||||
state: The state the OIDC provider gave back
|
||||
|
||||
Returns:
|
||||
The data extracted from the session cookie
|
||||
"""
|
||||
macaroon = pymacaroons.Macaroon.deserialize(session)
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_exact("gen = 1")
|
||||
v.satisfy_exact("type = session")
|
||||
v.satisfy_exact("state = %s" % (state,))
|
||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
||||
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
|
||||
# to always satisfy this.
|
||||
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
|
||||
v.verify(macaroon, self._macaroon_secret_key)
|
||||
|
||||
# Extract the `nonce`, `client_redirect_url`, and maybe the
|
||||
# `ui_auth_session_id` from the token.
|
||||
nonce = self._get_value_from_macaroon(macaroon, "nonce")
|
||||
client_redirect_url = self._get_value_from_macaroon(
|
||||
macaroon, "client_redirect_url"
|
||||
)
|
||||
try:
|
||||
ui_auth_session_id = self._get_value_from_macaroon(
|
||||
macaroon, "ui_auth_session_id"
|
||||
) # type: Optional[str]
|
||||
except ValueError:
|
||||
ui_auth_session_id = None
|
||||
|
||||
return OidcSessionData(
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=ui_auth_session_id,
|
||||
)
|
||||
|
||||
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
|
||||
"""Extracts a caveat value from a macaroon token.
|
||||
|
||||
Args:
|
||||
macaroon: the token
|
||||
key: the key of the caveat to extract
|
||||
|
||||
Returns:
|
||||
The extracted value
|
||||
|
||||
Raises:
|
||||
Exception: if the caveat was not in the macaroon
|
||||
"""
|
||||
prefix = key + " = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(prefix):
|
||||
return caveat.caveat_id[len(prefix) :]
|
||||
raise ValueError("No %s caveat in macaroon" % (key,))
|
||||
|
||||
def _verify_expiry(self, caveat: str) -> bool:
|
||||
prefix = "time < "
|
||||
if not caveat.startswith(prefix):
|
||||
return False
|
||||
expiry = int(caveat[len(prefix) :])
|
||||
now = self._clock.time_msec()
|
||||
return now < expiry
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class OidcSessionData:
|
||||
"""The attributes which are stored in a OIDC session cookie"""
|
||||
|
||||
# The `nonce` parameter passed to the OIDC provider.
|
||||
nonce = attr.ib(type=str)
|
||||
|
||||
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
|
||||
client_redirect_url = attr.ib(type=str)
|
||||
|
||||
# The session ID of the ongoing UI Auth (None if this is a login)
|
||||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
||||
|
||||
|
||||
UserAttributeDict = TypedDict(
|
||||
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import json
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
from mock import ANY, Mock, patch
|
||||
|
@ -349,9 +349,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
cookie = args[1]
|
||||
|
||||
macaroon = pymacaroons.Macaroon.deserialize(cookie)
|
||||
state = self.handler._get_value_from_macaroon(macaroon, "state")
|
||||
nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
|
||||
redirect = self.handler._get_value_from_macaroon(
|
||||
state = self.handler._token_generator._get_value_from_macaroon(
|
||||
macaroon, "state"
|
||||
)
|
||||
nonce = self.handler._token_generator._get_value_from_macaroon(
|
||||
macaroon, "nonce"
|
||||
)
|
||||
redirect = self.handler._token_generator._get_value_from_macaroon(
|
||||
macaroon, "client_redirect_url"
|
||||
)
|
||||
|
||||
|
@ -411,12 +415,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
client_redirect_url = "http://client/redirect"
|
||||
user_agent = "Browser"
|
||||
ip_address = "10.0.0.1"
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
)
|
||||
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
|
||||
request = _build_callback_request(
|
||||
code, state, session, user_agent=user_agent, ip_address=ip_address
|
||||
)
|
||||
|
@ -500,11 +499,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
self.assertRenderedError("invalid_session")
|
||||
|
||||
# Mismatching session
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
state="state",
|
||||
nonce="nonce",
|
||||
client_redirect_url="http://client/redirect",
|
||||
ui_auth_session_id=None,
|
||||
session = self._generate_oidc_session_token(
|
||||
state="state", nonce="nonce", client_redirect_url="http://client/redirect",
|
||||
)
|
||||
request.args = {}
|
||||
request.args[b"state"] = [b"mismatching state"]
|
||||
|
@ -623,11 +619,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
|
||||
state = "state"
|
||||
client_redirect_url = "http://client/redirect"
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce="nonce",
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
session = self._generate_oidc_session_token(
|
||||
state=state, nonce="nonce", client_redirect_url=client_redirect_url,
|
||||
)
|
||||
request = _build_callback_request("code", state, session)
|
||||
|
||||
|
@ -841,6 +834,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||
self.assertRenderedError("mapping_error", "localpart is invalid: ")
|
||||
|
||||
def _generate_oidc_session_token(
|
||||
self,
|
||||
state: str,
|
||||
nonce: str,
|
||||
client_redirect_url: str,
|
||||
ui_auth_session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
from synapse.handlers.oidc_handler import OidcSessionData
|
||||
|
||||
return self.handler._token_generator.generate_oidc_session_token(
|
||||
state=state,
|
||||
session_data=OidcSessionData(
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=ui_auth_session_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UsernamePickerTestCase(HomeserverTestCase):
|
||||
if not HAS_OIDC:
|
||||
|
@ -965,17 +976,19 @@ async def _make_callback_with_userinfo(
|
|||
userinfo: the OIDC userinfo dict
|
||||
client_redirect_url: the URL to redirect to on success.
|
||||
"""
|
||||
from synapse.handlers.oidc_handler import OidcSessionData
|
||||
|
||||
handler = hs.get_oidc_handler()
|
||||
handler._exchange_code = simple_async_mock(return_value={})
|
||||
handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
|
||||
state = "state"
|
||||
session = handler._generate_oidc_session_token(
|
||||
session = handler._token_generator.generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce="nonce",
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=None,
|
||||
session_data=OidcSessionData(
|
||||
nonce="nonce", client_redirect_url=client_redirect_url,
|
||||
),
|
||||
)
|
||||
request = _build_callback_request("code", state, session)
|
||||
|
||||
|
|
Loading…
Reference in a new issue