forked from MirrorHub/synapse
Save the OIDC session ID (sid) with the device on login (#11482)
As a step towards allowing back-channel logout for OIDC.
This commit is contained in:
parent
8b4b153c9e
commit
a15a893df8
15 changed files with 370 additions and 65 deletions
1
changelog.d/11482.misc
Normal file
1
changelog.d/11482.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Save the OpenID Connect session ID on login.
|
|
@ -39,6 +39,7 @@ import attr
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
import unpaddedbase64
|
import unpaddedbase64
|
||||||
|
from pymacaroons.exceptions import MacaroonVerificationFailedException
|
||||||
|
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
@ -182,8 +183,11 @@ class LoginTokenAttributes:
|
||||||
|
|
||||||
user_id = attr.ib(type=str)
|
user_id = attr.ib(type=str)
|
||||||
|
|
||||||
# the SSO Identity Provider that the user authenticated with, to get this token
|
|
||||||
auth_provider_id = attr.ib(type=str)
|
auth_provider_id = attr.ib(type=str)
|
||||||
|
"""The SSO Identity Provider that the user authenticated with, to get this token."""
|
||||||
|
|
||||||
|
auth_provider_session_id = attr.ib(type=Optional[str])
|
||||||
|
"""The session ID advertised by the SSO Identity Provider."""
|
||||||
|
|
||||||
|
|
||||||
class AuthHandler:
|
class AuthHandler:
|
||||||
|
@ -1650,6 +1654,7 @@ class AuthHandler:
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
new_user: bool = False,
|
new_user: bool = False,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Having figured out a mxid for this user, complete the HTTP request
|
"""Having figured out a mxid for this user, complete the HTTP request
|
||||||
|
|
||||||
|
@ -1665,6 +1670,7 @@ class AuthHandler:
|
||||||
during successful login. Must be JSON serializable.
|
during successful login. Must be JSON serializable.
|
||||||
new_user: True if we should use wording appropriate to a user who has just
|
new_user: True if we should use wording appropriate to a user who has just
|
||||||
registered.
|
registered.
|
||||||
|
auth_provider_session_id: The session ID from the SSO IdP received during login.
|
||||||
"""
|
"""
|
||||||
# If the account has been deactivated, do not proceed with the login
|
# If the account has been deactivated, do not proceed with the login
|
||||||
# flow.
|
# flow.
|
||||||
|
@ -1685,6 +1691,7 @@ class AuthHandler:
|
||||||
extra_attributes,
|
extra_attributes,
|
||||||
new_user=new_user,
|
new_user=new_user,
|
||||||
user_profile_data=profile,
|
user_profile_data=profile,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _complete_sso_login(
|
def _complete_sso_login(
|
||||||
|
@ -1696,6 +1703,7 @@ class AuthHandler:
|
||||||
extra_attributes: Optional[JsonDict] = None,
|
extra_attributes: Optional[JsonDict] = None,
|
||||||
new_user: bool = False,
|
new_user: bool = False,
|
||||||
user_profile_data: Optional[ProfileInfo] = None,
|
user_profile_data: Optional[ProfileInfo] = None,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
The synchronous portion of complete_sso_login.
|
The synchronous portion of complete_sso_login.
|
||||||
|
@ -1717,7 +1725,9 @@ class AuthHandler:
|
||||||
|
|
||||||
# Create a login token
|
# Create a login token
|
||||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||||
registered_user_id, auth_provider_id=auth_provider_id
|
registered_user_id,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Append the login token to the original redirect URL (i.e. with its query
|
# Append the login token to the original redirect URL (i.e. with its query
|
||||||
|
@ -1822,6 +1832,7 @@ class MacaroonGenerator:
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
auth_provider_id: str,
|
auth_provider_id: str,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
duration_in_ms: int = (2 * 60 * 1000),
|
duration_in_ms: int = (2 * 60 * 1000),
|
||||||
) -> str:
|
) -> str:
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
|
@ -1830,6 +1841,10 @@ class MacaroonGenerator:
|
||||||
expiry = now + duration_in_ms
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
|
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
|
||||||
|
if auth_provider_session_id is not None:
|
||||||
|
macaroon.add_first_party_caveat(
|
||||||
|
"auth_provider_session_id = %s" % (auth_provider_session_id,)
|
||||||
|
)
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
|
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
|
||||||
|
@ -1851,15 +1866,28 @@ class MacaroonGenerator:
|
||||||
user_id = get_value_from_macaroon(macaroon, "user_id")
|
user_id = get_value_from_macaroon(macaroon, "user_id")
|
||||||
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
|
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
|
||||||
|
|
||||||
|
auth_provider_session_id: Optional[str] = None
|
||||||
|
try:
|
||||||
|
auth_provider_session_id = get_value_from_macaroon(
|
||||||
|
macaroon, "auth_provider_session_id"
|
||||||
|
)
|
||||||
|
except MacaroonVerificationFailedException:
|
||||||
|
pass
|
||||||
|
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
v.satisfy_exact("gen = 1")
|
v.satisfy_exact("gen = 1")
|
||||||
v.satisfy_exact("type = login")
|
v.satisfy_exact("type = login")
|
||||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||||
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
|
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
|
||||||
|
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
|
||||||
satisfy_expiry(v, self.hs.get_clock().time_msec)
|
satisfy_expiry(v, self.hs.get_clock().time_msec)
|
||||||
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
||||||
|
|
||||||
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
|
return LoginTokenAttributes(
|
||||||
|
user_id=user_id,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
def generate_delete_pusher_token(self, user_id: str) -> str:
|
def generate_delete_pusher_token(self, user_id: str) -> str:
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
|
|
|
@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
device_id: Optional[str],
|
device_id: Optional[str],
|
||||||
initial_device_display_name: Optional[str] = None,
|
initial_device_display_name: Optional[str] = None,
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
If the given device has not been registered, register it with the
|
If the given device has not been registered, register it with the
|
||||||
|
@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
user_id: @user:id
|
user_id: @user:id
|
||||||
device_id: device id supplied by client
|
device_id: device id supplied by client
|
||||||
initial_device_display_name: device display name from client
|
initial_device_display_name: device display name from client
|
||||||
|
auth_provider_id: The SSO IdP the user used, if any.
|
||||||
|
auth_provider_session_id: The session ID (sid) got from the SSO IdP.
|
||||||
Returns:
|
Returns:
|
||||||
device id (generated if none was supplied)
|
device id (generated if none was supplied)
|
||||||
"""
|
"""
|
||||||
|
@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
initial_device_display_name=initial_device_display_name,
|
initial_device_display_name=initial_device_display_name,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
if new_device:
|
if new_device:
|
||||||
await self.notify_device_update(user_id, [device_id])
|
await self.notify_device_update(user_id, [device_id])
|
||||||
|
@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
device_id=new_device_id,
|
device_id=new_device_id,
|
||||||
initial_device_display_name=initial_device_display_name,
|
initial_device_display_name=initial_device_display_name,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
if new_device:
|
if new_device:
|
||||||
await self.notify_device_update(user_id, [new_device_id])
|
await self.notify_device_update(user_id, [new_device_id])
|
||||||
|
|
|
@ -23,7 +23,7 @@ from authlib.common.security import generate_token
|
||||||
from authlib.jose import JsonWebToken, jwt
|
from authlib.jose import JsonWebToken, jwt
|
||||||
from authlib.oauth2.auth import ClientAuth
|
from authlib.oauth2.auth import ClientAuth
|
||||||
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
||||||
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
|
from authlib.oidc.core import CodeIDToken, UserInfo
|
||||||
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
|
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
|
||||||
from jinja2 import Environment, Template
|
from jinja2 import Environment, Template
|
||||||
from pymacaroons.exceptions import (
|
from pymacaroons.exceptions import (
|
||||||
|
@ -117,7 +117,8 @@ class OidcHandler:
|
||||||
for idp_id, p in self._providers.items():
|
for idp_id, p in self._providers.items():
|
||||||
try:
|
try:
|
||||||
await p.load_metadata()
|
await p.load_metadata()
|
||||||
await p.load_jwks()
|
if not p._uses_userinfo:
|
||||||
|
await p.load_jwks()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Error while initialising OIDC provider %r" % (idp_id,)
|
"Error while initialising OIDC provider %r" % (idp_id,)
|
||||||
|
@ -498,10 +499,6 @@ class OidcProvider:
|
||||||
return await self._jwks.get()
|
return await self._jwks.get()
|
||||||
|
|
||||||
async def _load_jwks(self) -> JWKS:
|
async def _load_jwks(self) -> JWKS:
|
||||||
if self._uses_userinfo:
|
|
||||||
# We're not using jwt signing, return an empty jwk set
|
|
||||||
return {"keys": []}
|
|
||||||
|
|
||||||
metadata = await self.load_metadata()
|
metadata = await self.load_metadata()
|
||||||
|
|
||||||
# Load the JWKS using the `jwks_uri` metadata.
|
# Load the JWKS using the `jwks_uri` metadata.
|
||||||
|
@ -663,7 +660,7 @@ class OidcProvider:
|
||||||
|
|
||||||
return UserInfo(resp)
|
return UserInfo(resp)
|
||||||
|
|
||||||
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
|
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
|
||||||
"""Return an instance of UserInfo from token's ``id_token``.
|
"""Return an instance of UserInfo from token's ``id_token``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -673,7 +670,7 @@ class OidcProvider:
|
||||||
request. This value should match the one inside the token.
|
request. This value should match the one inside the token.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An object representing the user.
|
The decoded claims in the ID token.
|
||||||
"""
|
"""
|
||||||
metadata = await self.load_metadata()
|
metadata = await self.load_metadata()
|
||||||
claims_params = {
|
claims_params = {
|
||||||
|
@ -684,9 +681,6 @@ class OidcProvider:
|
||||||
# If we got an `access_token`, there should be an `at_hash` claim
|
# If we got an `access_token`, there should be an `at_hash` claim
|
||||||
# in the `id_token` that we can check against.
|
# in the `id_token` that we can check against.
|
||||||
claims_params["access_token"] = token["access_token"]
|
claims_params["access_token"] = token["access_token"]
|
||||||
claims_cls = CodeIDToken
|
|
||||||
else:
|
|
||||||
claims_cls = ImplicitIDToken
|
|
||||||
|
|
||||||
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
||||||
jwt = JsonWebToken(alg_values)
|
jwt = JsonWebToken(alg_values)
|
||||||
|
@ -703,7 +697,7 @@ class OidcProvider:
|
||||||
claims = jwt.decode(
|
claims = jwt.decode(
|
||||||
id_token,
|
id_token,
|
||||||
key=jwk_set,
|
key=jwk_set,
|
||||||
claims_cls=claims_cls,
|
claims_cls=CodeIDToken,
|
||||||
claims_options=claim_options,
|
claims_options=claim_options,
|
||||||
claims_params=claims_params,
|
claims_params=claims_params,
|
||||||
)
|
)
|
||||||
|
@ -713,7 +707,7 @@ class OidcProvider:
|
||||||
claims = jwt.decode(
|
claims = jwt.decode(
|
||||||
id_token,
|
id_token,
|
||||||
key=jwk_set,
|
key=jwk_set,
|
||||||
claims_cls=claims_cls,
|
claims_cls=CodeIDToken,
|
||||||
claims_options=claim_options,
|
claims_options=claim_options,
|
||||||
claims_params=claims_params,
|
claims_params=claims_params,
|
||||||
)
|
)
|
||||||
|
@ -721,7 +715,8 @@ class OidcProvider:
|
||||||
logger.debug("Decoded id_token JWT %r; validating", claims)
|
logger.debug("Decoded id_token JWT %r; validating", claims)
|
||||||
|
|
||||||
claims.validate(leeway=120) # allows 2 min of clock skew
|
claims.validate(leeway=120) # allows 2 min of clock skew
|
||||||
return UserInfo(claims)
|
|
||||||
|
return claims
|
||||||
|
|
||||||
async def handle_redirect_request(
|
async def handle_redirect_request(
|
||||||
self,
|
self,
|
||||||
|
@ -837,8 +832,22 @@ class OidcProvider:
|
||||||
|
|
||||||
logger.debug("Successfully obtained OAuth2 token data: %r", token)
|
logger.debug("Successfully obtained OAuth2 token data: %r", token)
|
||||||
|
|
||||||
# Now that we have a token, get the userinfo, either by decoding the
|
# If there is an id_token, it should be validated, regardless of the
|
||||||
# `id_token` or by fetching the `userinfo_endpoint`.
|
# userinfo endpoint is used or not.
|
||||||
|
if token.get("id_token") is not None:
|
||||||
|
try:
|
||||||
|
id_token = await self._parse_id_token(token, nonce=session_data.nonce)
|
||||||
|
sid = id_token.get("sid")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Invalid id_token")
|
||||||
|
self._sso_handler.render_error(request, "invalid_token", str(e))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
id_token = None
|
||||||
|
sid = None
|
||||||
|
|
||||||
|
# Now that we have a token, get the userinfo either from the `id_token`
|
||||||
|
# claims or by fetching the `userinfo_endpoint`.
|
||||||
if self._uses_userinfo:
|
if self._uses_userinfo:
|
||||||
try:
|
try:
|
||||||
userinfo = await self._fetch_userinfo(token)
|
userinfo = await self._fetch_userinfo(token)
|
||||||
|
@ -846,13 +855,14 @@ class OidcProvider:
|
||||||
logger.exception("Could not fetch userinfo")
|
logger.exception("Could not fetch userinfo")
|
||||||
self._sso_handler.render_error(request, "fetch_error", str(e))
|
self._sso_handler.render_error(request, "fetch_error", str(e))
|
||||||
return
|
return
|
||||||
|
elif id_token is not None:
|
||||||
|
userinfo = UserInfo(id_token)
|
||||||
else:
|
else:
|
||||||
try:
|
logger.error("Missing id_token in token response")
|
||||||
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
|
self._sso_handler.render_error(
|
||||||
except Exception as e:
|
request, "invalid_token", "Missing id_token in token response"
|
||||||
logger.exception("Invalid id_token")
|
)
|
||||||
self._sso_handler.render_error(request, "invalid_token", str(e))
|
return
|
||||||
return
|
|
||||||
|
|
||||||
# first check if we're doing a UIA
|
# first check if we're doing a UIA
|
||||||
if session_data.ui_auth_session_id:
|
if session_data.ui_auth_session_id:
|
||||||
|
@ -884,7 +894,7 @@ class OidcProvider:
|
||||||
# Call the mapper to register/login the user
|
# Call the mapper to register/login the user
|
||||||
try:
|
try:
|
||||||
await self._complete_oidc_login(
|
await self._complete_oidc_login(
|
||||||
userinfo, token, request, session_data.client_redirect_url
|
userinfo, token, request, session_data.client_redirect_url, sid
|
||||||
)
|
)
|
||||||
except MappingException as e:
|
except MappingException as e:
|
||||||
logger.exception("Could not map user")
|
logger.exception("Could not map user")
|
||||||
|
@ -896,6 +906,7 @@ class OidcProvider:
|
||||||
token: Token,
|
token: Token,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
client_redirect_url: str,
|
client_redirect_url: str,
|
||||||
|
sid: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Given a UserInfo response, complete the login flow
|
"""Given a UserInfo response, complete the login flow
|
||||||
|
|
||||||
|
@ -1008,6 +1019,7 @@ class OidcProvider:
|
||||||
oidc_response_to_user_attributes,
|
oidc_response_to_user_attributes,
|
||||||
grandfather_existing_users,
|
grandfather_existing_users,
|
||||||
extra_attributes,
|
extra_attributes,
|
||||||
|
auth_provider_session_id=sid,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|
||||||
|
|
|
@ -746,6 +746,7 @@ class RegistrationHandler:
|
||||||
is_appservice_ghost: bool = False,
|
is_appservice_ghost: bool = False,
|
||||||
auth_provider_id: Optional[str] = None,
|
auth_provider_id: Optional[str] = None,
|
||||||
should_issue_refresh_token: bool = False,
|
should_issue_refresh_token: bool = False,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> Tuple[str, str, Optional[int], Optional[str]]:
|
) -> Tuple[str, str, Optional[int], Optional[str]]:
|
||||||
"""Register a device for a user and generate an access token.
|
"""Register a device for a user and generate an access token.
|
||||||
|
|
||||||
|
@ -756,9 +757,9 @@ class RegistrationHandler:
|
||||||
device_id: The device ID to check, or None to generate a new one.
|
device_id: The device ID to check, or None to generate a new one.
|
||||||
initial_display_name: An optional display name for the device.
|
initial_display_name: An optional display name for the device.
|
||||||
is_guest: Whether this is a guest account
|
is_guest: Whether this is a guest account
|
||||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
auth_provider_id: The SSO IdP the user used, if any.
|
||||||
prometheus metrics).
|
|
||||||
should_issue_refresh_token: Whether it should also issue a refresh token
|
should_issue_refresh_token: Whether it should also issue a refresh token
|
||||||
|
auth_provider_session_id: The session ID received during login from the SSO IdP.
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of device ID, access token, access token expiration time and refresh token
|
Tuple of device ID, access token, access token expiration time and refresh token
|
||||||
"""
|
"""
|
||||||
|
@ -769,6 +770,8 @@ class RegistrationHandler:
|
||||||
is_guest=is_guest,
|
is_guest=is_guest,
|
||||||
is_appservice_ghost=is_appservice_ghost,
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
login_counter.labels(
|
login_counter.labels(
|
||||||
|
@ -791,6 +794,8 @@ class RegistrationHandler:
|
||||||
is_guest: bool = False,
|
is_guest: bool = False,
|
||||||
is_appservice_ghost: bool = False,
|
is_appservice_ghost: bool = False,
|
||||||
should_issue_refresh_token: bool = False,
|
should_issue_refresh_token: bool = False,
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> LoginDict:
|
) -> LoginDict:
|
||||||
"""Helper for register_device
|
"""Helper for register_device
|
||||||
|
|
||||||
|
@ -822,7 +827,11 @@ class RegistrationHandler:
|
||||||
refresh_token_id = None
|
refresh_token_id = None
|
||||||
|
|
||||||
registered_device_id = await self.device_handler.check_device_registered(
|
registered_device_id = await self.device_handler.check_device_registered(
|
||||||
user_id, device_id, initial_display_name
|
user_id,
|
||||||
|
device_id,
|
||||||
|
initial_display_name,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
if is_guest:
|
if is_guest:
|
||||||
assert access_token_expiry is None
|
assert access_token_expiry is None
|
||||||
|
|
|
@ -365,6 +365,7 @@ class SsoHandler:
|
||||||
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
|
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
|
||||||
extra_login_attributes: Optional[JsonDict] = None,
|
extra_login_attributes: Optional[JsonDict] = None,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
||||||
|
@ -415,6 +416,8 @@ class SsoHandler:
|
||||||
extra_login_attributes: An optional dictionary of extra
|
extra_login_attributes: An optional dictionary of extra
|
||||||
attributes to be provided to the client in the login response.
|
attributes to be provided to the client in the login response.
|
||||||
|
|
||||||
|
auth_provider_session_id: An optional session ID from the IdP.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MappingException if there was a problem mapping the response to a user.
|
MappingException if there was a problem mapping the response to a user.
|
||||||
RedirectException: if the mapping provider needs to redirect the user
|
RedirectException: if the mapping provider needs to redirect the user
|
||||||
|
@ -490,6 +493,7 @@ class SsoHandler:
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
extra_login_attributes,
|
extra_login_attributes,
|
||||||
new_user=new_user,
|
new_user=new_user,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _call_attribute_mapper(
|
async def _call_attribute_mapper(
|
||||||
|
|
|
@ -626,6 +626,7 @@ class ModuleApi:
|
||||||
user_id: str,
|
user_id: str,
|
||||||
duration_in_ms: int = (2 * 60 * 1000),
|
duration_in_ms: int = (2 * 60 * 1000),
|
||||||
auth_provider_id: str = "",
|
auth_provider_id: str = "",
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a login token suitable for m.login.token authentication
|
"""Generate a login token suitable for m.login.token authentication
|
||||||
|
|
||||||
|
@ -643,6 +644,7 @@ class ModuleApi:
|
||||||
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
||||||
user_id,
|
user_id,
|
||||||
auth_provider_id,
|
auth_provider_id,
|
||||||
|
auth_provider_session_id,
|
||||||
duration_in_ms,
|
duration_in_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
is_guest,
|
is_guest,
|
||||||
is_appservice_ghost,
|
is_appservice_ghost,
|
||||||
should_issue_refresh_token,
|
should_issue_refresh_token,
|
||||||
|
auth_provider_id,
|
||||||
|
auth_provider_session_id,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
"is_guest": is_guest,
|
"is_guest": is_guest,
|
||||||
"is_appservice_ghost": is_appservice_ghost,
|
"is_appservice_ghost": is_appservice_ghost,
|
||||||
"should_issue_refresh_token": should_issue_refresh_token,
|
"should_issue_refresh_token": should_issue_refresh_token,
|
||||||
|
"auth_provider_id": auth_provider_id,
|
||||||
|
"auth_provider_session_id": auth_provider_session_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request(self, request, user_id):
|
||||||
|
@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
is_guest = content["is_guest"]
|
is_guest = content["is_guest"]
|
||||||
is_appservice_ghost = content["is_appservice_ghost"]
|
is_appservice_ghost = content["is_appservice_ghost"]
|
||||||
should_issue_refresh_token = content["should_issue_refresh_token"]
|
should_issue_refresh_token = content["should_issue_refresh_token"]
|
||||||
|
auth_provider_id = content["auth_provider_id"]
|
||||||
|
auth_provider_session_id = content["auth_provider_session_id"]
|
||||||
|
|
||||||
res = await self.registration_handler.register_device_inner(
|
res = await self.registration_handler.register_device_inner(
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
is_guest,
|
is_guest,
|
||||||
is_appservice_ghost=is_appservice_ghost,
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, res
|
return 200, res
|
||||||
|
|
|
@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet):
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
auth_provider_id: Optional[str] = None,
|
auth_provider_id: Optional[str] = None,
|
||||||
should_issue_refresh_token: bool = False,
|
should_issue_refresh_token: bool = False,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> LoginResponse:
|
) -> LoginResponse:
|
||||||
"""Called when we've successfully authed the user and now need to
|
"""Called when we've successfully authed the user and now need to
|
||||||
actually login them in (e.g. create devices). This gets called on
|
actually login them in (e.g. create devices). This gets called on
|
||||||
|
@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet):
|
||||||
create_non_existent_users: Whether to create the user if they don't
|
create_non_existent_users: Whether to create the user if they don't
|
||||||
exist. Defaults to False.
|
exist. Defaults to False.
|
||||||
ratelimit: Whether to ratelimit the login request.
|
ratelimit: Whether to ratelimit the login request.
|
||||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
auth_provider_id: The SSO IdP the user used, if any.
|
||||||
prometheus metrics).
|
|
||||||
should_issue_refresh_token: True if this login should issue
|
should_issue_refresh_token: True if this login should issue
|
||||||
a refresh token alongside the access token.
|
a refresh token alongside the access token.
|
||||||
|
auth_provider_session_id: The session ID got during login from the SSO IdP.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
result: Dictionary of account information after successful login.
|
result: Dictionary of account information after successful login.
|
||||||
|
@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet):
|
||||||
initial_display_name,
|
initial_display_name,
|
||||||
auth_provider_id=auth_provider_id,
|
auth_provider_id=auth_provider_id,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = LoginResponse(
|
result = LoginResponse(
|
||||||
|
@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet):
|
||||||
self.auth_handler._sso_login_callback,
|
self.auth_handler._sso_login_callback,
|
||||||
auth_provider_id=res.auth_provider_id,
|
auth_provider_id=res.auth_provider_id,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
auth_provider_session_id=res.auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _do_jwt_login(
|
async def _do_jwt_login(
|
||||||
|
|
|
@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return {d["device_id"]: d for d in devices}
|
return {d["device_id"]: d for d in devices}
|
||||||
|
|
||||||
|
async def get_devices_by_auth_provider_session_id(
|
||||||
|
self, auth_provider_id: str, auth_provider_session_id: str
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Retrieve the list of devices associated with a SSO IdP session ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_provider_id: The SSO IdP ID as defined in the server config
|
||||||
|
auth_provider_session_id: The session ID within the IdP
|
||||||
|
Returns:
|
||||||
|
A list of dicts containing the device_id and the user_id of each device
|
||||||
|
"""
|
||||||
|
return await self.db_pool.simple_select_list(
|
||||||
|
table="device_auth_providers",
|
||||||
|
keyvalues={
|
||||||
|
"auth_provider_id": auth_provider_id,
|
||||||
|
"auth_provider_session_id": auth_provider_session_id,
|
||||||
|
},
|
||||||
|
retcols=("user_id", "device_id"),
|
||||||
|
desc="get_devices_by_auth_provider_session_id",
|
||||||
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def get_device_updates_by_remote(
|
async def get_device_updates_by_remote(
|
||||||
self, destination: str, from_stream_id: int, limit: int
|
self, destination: str, from_stream_id: int, limit: int
|
||||||
|
@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def store_device(
|
async def store_device(
|
||||||
self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
initial_device_display_name: Optional[str],
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
auth_provider_session_id: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Ensure the given device is known; add it to the store if not
|
"""Ensure the given device is known; add it to the store if not
|
||||||
|
|
||||||
|
@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
device_id: id of device
|
device_id: id of device
|
||||||
initial_device_display_name: initial displayname of the device.
|
initial_device_display_name: initial displayname of the device.
|
||||||
Ignored if device exists.
|
Ignored if device exists.
|
||||||
|
auth_provider_id: The SSO IdP the user used, if any.
|
||||||
|
auth_provider_session_id: The session ID (sid) got from a OIDC login.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Whether the device was inserted or an existing device existed with that ID.
|
Whether the device was inserted or an existing device existed with that ID.
|
||||||
|
@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
if hidden:
|
if hidden:
|
||||||
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
|
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
if auth_provider_id and auth_provider_session_id:
|
||||||
|
await self.db_pool.simple_insert(
|
||||||
|
"device_auth_providers",
|
||||||
|
values={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"auth_provider_id": auth_provider_id,
|
||||||
|
"auth_provider_session_id": auth_provider_session_id,
|
||||||
|
},
|
||||||
|
desc="store_device_auth_provider",
|
||||||
|
)
|
||||||
|
|
||||||
self.device_id_exists_cache.set(key, True)
|
self.device_id_exists_cache.set(key, True)
|
||||||
return inserted
|
return inserted
|
||||||
except StoreError:
|
except StoreError:
|
||||||
|
@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.db_pool.simple_delete_many_txn(
|
||||||
|
txn,
|
||||||
|
table="device_auth_providers",
|
||||||
|
column="device_id",
|
||||||
|
values=device_ids,
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
)
|
||||||
|
|
||||||
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
|
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
|
||||||
for device_id in device_ids:
|
for device_id in device_ids:
|
||||||
self.device_id_exists_cache.invalidate((user_id, device_id))
|
self.device_id_exists_cache.invalidate((user_id, device_id))
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Track the auth provider used by each login as well as the session ID
|
||||||
|
CREATE TABLE device_auth_providers (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
auth_provider_id TEXT NOT NULL,
|
||||||
|
auth_provider_session_id TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX device_auth_providers_devices
|
||||||
|
ON device_auth_providers (user_id, device_id);
|
||||||
|
CREATE INDEX device_auth_providers_sessions
|
||||||
|
ON device_auth_providers (auth_provider_id, auth_provider_session_id);
|
|
@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
self.user1, "", 5000
|
self.user1, "", duration_in_ms=5000
|
||||||
)
|
)
|
||||||
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
||||||
self.assertEqual(self.user1, res.user_id)
|
self.assertEqual(self.user1, res.user_id)
|
||||||
|
@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
self.user1, "", 5000
|
self.user1, "", duration_in_ms=5000
|
||||||
)
|
)
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
|
@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def _get_macaroon(self):
|
def _get_macaroon(self):
|
||||||
token = self.macaroon_generator.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
self.user1, "", 5000
|
self.user1, "", duration_in_ms=5000
|
||||||
)
|
)
|
||||||
return pymacaroons.Macaroon.deserialize(token)
|
return pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
|
@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
|
"@test_user:test",
|
||||||
|
"cas",
|
||||||
|
request,
|
||||||
|
"redirect_uri",
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_cas_user_to_existing_user(self):
|
def test_map_cas_user_to_existing_user(self):
|
||||||
|
@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
|
"@test_user:test",
|
||||||
|
"cas",
|
||||||
|
request,
|
||||||
|
"redirect_uri",
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
|
@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
|
"@test_user:test",
|
||||||
|
"cas",
|
||||||
|
request,
|
||||||
|
"redirect_uri",
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_cas_user_to_invalid_localpart(self):
|
def test_map_cas_user_to_invalid_localpart(self):
|
||||||
|
@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
|
"@f=c3=b6=c3=b6:test",
|
||||||
|
"cas",
|
||||||
|
request,
|
||||||
|
"redirect_uri",
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
|
"@test_user:test",
|
||||||
|
"cas",
|
||||||
|
request,
|
||||||
|
"redirect_uri",
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
with patch.object(self.provider, "load_metadata", patched_load_metadata):
|
with patch.object(self.provider, "load_metadata", patched_load_metadata):
|
||||||
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
||||||
|
|
||||||
# Return empty key set if JWKS are not used
|
|
||||||
self.provider._scopes = [] # not asking the openid scope
|
|
||||||
self.http_client.get_json.reset_mock()
|
|
||||||
jwks = self.get_success(self.provider.load_jwks(force=True))
|
|
||||||
self.http_client.get_json.assert_not_called()
|
|
||||||
self.assertEqual(jwks, {"keys": []})
|
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_validate_config(self):
|
def test_validate_config(self):
|
||||||
"""Provider metadatas are extensively validated."""
|
"""Provider metadatas are extensively validated."""
|
||||||
|
@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
|
expected_user_id,
|
||||||
|
"oidc",
|
||||||
|
request,
|
||||||
|
client_redirect_url,
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
self.provider._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||||
|
@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.provider._fetch_userinfo.reset_mock()
|
self.provider._fetch_userinfo.reset_mock()
|
||||||
|
|
||||||
# With userinfo fetching
|
# With userinfo fetching
|
||||||
self.provider._scopes = [] # do not ask the "openid" scope
|
self.provider._user_profile_method = "userinfo_endpoint"
|
||||||
|
token = {
|
||||||
|
"type": "bearer",
|
||||||
|
"access_token": "access_token",
|
||||||
|
}
|
||||||
|
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
|
expected_user_id,
|
||||||
|
"oidc",
|
||||||
|
request,
|
||||||
|
client_redirect_url,
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
self.provider._exchange_code.assert_called_once_with(code)
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
self.provider._parse_id_token.assert_not_called()
|
self.provider._parse_id_token.assert_not_called()
|
||||||
self.provider._fetch_userinfo.assert_called_once_with(token)
|
self.provider._fetch_userinfo.assert_called_once_with(token)
|
||||||
self.render_error.assert_not_called()
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
|
# With an ID token, userinfo fetching and sid in the ID token
|
||||||
|
self.provider._user_profile_method = "userinfo_endpoint"
|
||||||
|
token = {
|
||||||
|
"type": "bearer",
|
||||||
|
"access_token": "access_token",
|
||||||
|
"id_token": "id_token",
|
||||||
|
}
|
||||||
|
id_token = {
|
||||||
|
"sid": "abcdefgh",
|
||||||
|
}
|
||||||
|
self.provider._parse_id_token = simple_async_mock(return_value=id_token)
|
||||||
|
self.provider._exchange_code = simple_async_mock(return_value=token)
|
||||||
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
self.provider._fetch_userinfo.reset_mock()
|
||||||
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
|
||||||
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
|
expected_user_id,
|
||||||
|
"oidc",
|
||||||
|
request,
|
||||||
|
client_redirect_url,
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=id_token["sid"],
|
||||||
|
)
|
||||||
|
self.provider._exchange_code.assert_called_once_with(code)
|
||||||
|
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||||
|
self.provider._fetch_userinfo.assert_called_once_with(token)
|
||||||
|
self.render_error.assert_not_called()
|
||||||
|
|
||||||
# Handle userinfo fetching error
|
# Handle userinfo fetching error
|
||||||
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
|
@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
client_redirect_url,
|
client_redirect_url,
|
||||||
{"phone": "1234567"},
|
{"phone": "1234567"},
|
||||||
new_user=True,
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
|
@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "oidc", ANY, ANY, None, new_user=True
|
"@test_user:test",
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
|
"@test_user_2:test",
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
user.to_string(),
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
user.to_string(),
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
user.to_string(),
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
|
"@TEST_USER_2:test",
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": DEFAULT_CONFIG})
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
|
@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user1:test", "oidc", ANY, ANY, None, new_user=True
|
"@test_user1:test",
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@tester:test", "oidc", ANY, ANY, None, new_user=True
|
"@tester:test",
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@tester:test", "oidc", ANY, ANY, None, new_user=True
|
"@tester:test",
|
||||||
|
"oidc",
|
||||||
|
ANY,
|
||||||
|
ANY,
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo(
|
||||||
|
|
||||||
handler = hs.get_oidc_handler()
|
handler = hs.get_oidc_handler()
|
||||||
provider = handler._providers["oidc"]
|
provider = handler._providers["oidc"]
|
||||||
provider._exchange_code = simple_async_mock(return_value={})
|
provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
|
||||||
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
||||||
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||||
|
|
||||||
|
|
|
@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
|
"@test_user:test",
|
||||||
|
"saml",
|
||||||
|
request,
|
||||||
|
"redirect_uri",
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||||
|
@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "saml", request, "", None, new_user=False
|
"@test_user:test",
|
||||||
|
"saml",
|
||||||
|
request,
|
||||||
|
"",
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subsequent calls should map to the same mxid.
|
# Subsequent calls should map to the same mxid.
|
||||||
|
@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
self.handler._handle_authn_response(request, saml_response, "")
|
self.handler._handle_authn_response(request, saml_response, "")
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "saml", request, "", None, new_user=False
|
"@test_user:test",
|
||||||
|
"saml",
|
||||||
|
request,
|
||||||
|
"",
|
||||||
|
None,
|
||||||
|
new_user=False,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_map_saml_response_to_invalid_localpart(self):
|
def test_map_saml_response_to_invalid_localpart(self):
|
||||||
|
@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# test_user is already taken, so test_user1 gets registered instead.
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user1:test", "saml", request, "", None, new_user=True
|
"@test_user1:test",
|
||||||
|
"saml",
|
||||||
|
request,
|
||||||
|
"",
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
auth_handler.complete_sso_login.reset_mock()
|
auth_handler.complete_sso_login.reset_mock()
|
||||||
|
|
||||||
|
@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# check that the auth handler got called as expected
|
# check that the auth handler got called as expected
|
||||||
auth_handler.complete_sso_login.assert_called_once_with(
|
auth_handler.complete_sso_login.assert_called_once_with(
|
||||||
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
|
"@test_user:test",
|
||||||
|
"saml",
|
||||||
|
request,
|
||||||
|
"redirect_uri",
|
||||||
|
None,
|
||||||
|
new_user=True,
|
||||||
|
auth_provider_session_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue