Save the OIDC session ID (sid) with the device on login (#11482)

As a step towards allowing back-channel logout for OIDC.
This commit is contained in:
Quentin Gliech 2021-12-06 18:43:06 +01:00 committed by GitHub
parent 8b4b153c9e
commit a15a893df8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 370 additions and 65 deletions

1
changelog.d/11482.misc Normal file
View file

@ -0,0 +1 @@
Save the OpenID Connect session ID on login.

View file

@ -39,6 +39,7 @@ import attr
import bcrypt import bcrypt
import pymacaroons import pymacaroons
import unpaddedbase64 import unpaddedbase64
from pymacaroons.exceptions import MacaroonVerificationFailedException
from twisted.web.server import Request from twisted.web.server import Request
@ -182,8 +183,11 @@ class LoginTokenAttributes:
user_id = attr.ib(type=str) user_id = attr.ib(type=str)
# the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str) auth_provider_id = attr.ib(type=str)
"""The SSO Identity Provider that the user authenticated with, to get this token."""
auth_provider_session_id = attr.ib(type=Optional[str])
"""The session ID advertised by the SSO Identity Provider."""
class AuthHandler: class AuthHandler:
@ -1650,6 +1654,7 @@ class AuthHandler:
client_redirect_url: str, client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None, extra_attributes: Optional[JsonDict] = None,
new_user: bool = False, new_user: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> None: ) -> None:
"""Having figured out a mxid for this user, complete the HTTP request """Having figured out a mxid for this user, complete the HTTP request
@ -1665,6 +1670,7 @@ class AuthHandler:
during successful login. Must be JSON serializable. during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just new_user: True if we should use wording appropriate to a user who has just
registered. registered.
auth_provider_session_id: The session ID from the SSO IdP received during login.
""" """
# If the account has been deactivated, do not proceed with the login # If the account has been deactivated, do not proceed with the login
# flow. # flow.
@ -1685,6 +1691,7 @@ class AuthHandler:
extra_attributes, extra_attributes,
new_user=new_user, new_user=new_user,
user_profile_data=profile, user_profile_data=profile,
auth_provider_session_id=auth_provider_session_id,
) )
def _complete_sso_login( def _complete_sso_login(
@ -1696,6 +1703,7 @@ class AuthHandler:
extra_attributes: Optional[JsonDict] = None, extra_attributes: Optional[JsonDict] = None,
new_user: bool = False, new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None, user_profile_data: Optional[ProfileInfo] = None,
auth_provider_session_id: Optional[str] = None,
) -> None: ) -> None:
""" """
The synchronous portion of complete_sso_login. The synchronous portion of complete_sso_login.
@ -1717,7 +1725,9 @@ class AuthHandler:
# Create a login token # Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token( login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id, auth_provider_id=auth_provider_id registered_user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
) )
# Append the login token to the original redirect URL (i.e. with its query # Append the login token to the original redirect URL (i.e. with its query
@ -1822,6 +1832,7 @@ class MacaroonGenerator:
self, self,
user_id: str, user_id: str,
auth_provider_id: str, auth_provider_id: str,
auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000), duration_in_ms: int = (2 * 60 * 1000),
) -> str: ) -> str:
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
@ -1830,6 +1841,10 @@ class MacaroonGenerator:
expiry = now + duration_in_ms expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
if auth_provider_session_id is not None:
macaroon.add_first_party_caveat(
"auth_provider_session_id = %s" % (auth_provider_session_id,)
)
return macaroon.serialize() return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@ -1851,15 +1866,28 @@ class MacaroonGenerator:
user_id = get_value_from_macaroon(macaroon, "user_id") user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
auth_provider_session_id: Optional[str] = None
try:
auth_provider_session_id = get_value_from_macaroon(
macaroon, "auth_provider_session_id"
)
except MacaroonVerificationFailedException:
pass
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1") v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login") v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec) satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key) v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) return LoginTokenAttributes(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
def generate_delete_pusher_token(self, user_id: str) -> str: def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)

View file

@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: str, user_id: str,
device_id: Optional[str], device_id: Optional[str],
initial_device_display_name: Optional[str] = None, initial_device_display_name: Optional[str] = None,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> str: ) -> str:
""" """
If the given device has not been registered, register it with the If the given device has not been registered, register it with the
@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: @user:id user_id: @user:id
device_id: device id supplied by client device_id: device id supplied by client
initial_device_display_name: device display name from client initial_device_display_name: device display name from client
auth_provider_id: The SSO IdP the user used, if any.
auth_provider_session_id: The session ID (sid) got from the SSO IdP.
Returns: Returns:
device id (generated if none was supplied) device id (generated if none was supplied)
""" """
@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
) )
if new_device: if new_device:
await self.notify_device_update(user_id, [device_id]) await self.notify_device_update(user_id, [device_id])
@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id, user_id=user_id,
device_id=new_device_id, device_id=new_device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
) )
if new_device: if new_device:
await self.notify_device_update(user_id, [new_device_id]) await self.notify_device_update(user_id, [new_device_id])

View file

@ -23,7 +23,7 @@ from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template from jinja2 import Environment, Template
from pymacaroons.exceptions import ( from pymacaroons.exceptions import (
@ -117,7 +117,8 @@ class OidcHandler:
for idp_id, p in self._providers.items(): for idp_id, p in self._providers.items():
try: try:
await p.load_metadata() await p.load_metadata()
await p.load_jwks() if not p._uses_userinfo:
await p.load_jwks()
except Exception as e: except Exception as e:
raise Exception( raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,) "Error while initialising OIDC provider %r" % (idp_id,)
@ -498,10 +499,6 @@ class OidcProvider:
return await self._jwks.get() return await self._jwks.get()
async def _load_jwks(self) -> JWKS: async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}
metadata = await self.load_metadata() metadata = await self.load_metadata()
# Load the JWKS using the `jwks_uri` metadata. # Load the JWKS using the `jwks_uri` metadata.
@ -663,7 +660,7 @@ class OidcProvider:
return UserInfo(resp) return UserInfo(resp)
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``. """Return an instance of UserInfo from token's ``id_token``.
Args: Args:
@ -673,7 +670,7 @@ class OidcProvider:
request. This value should match the one inside the token. request. This value should match the one inside the token.
Returns: Returns:
An object representing the user. The decoded claims in the ID token.
""" """
metadata = await self.load_metadata() metadata = await self.load_metadata()
claims_params = { claims_params = {
@ -684,9 +681,6 @@ class OidcProvider:
# If we got an `access_token`, there should be an `at_hash` claim # If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against. # in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"] claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values) jwt = JsonWebToken(alg_values)
@ -703,7 +697,7 @@ class OidcProvider:
claims = jwt.decode( claims = jwt.decode(
id_token, id_token,
key=jwk_set, key=jwk_set,
claims_cls=claims_cls, claims_cls=CodeIDToken,
claims_options=claim_options, claims_options=claim_options,
claims_params=claims_params, claims_params=claims_params,
) )
@ -713,7 +707,7 @@ class OidcProvider:
claims = jwt.decode( claims = jwt.decode(
id_token, id_token,
key=jwk_set, key=jwk_set,
claims_cls=claims_cls, claims_cls=CodeIDToken,
claims_options=claim_options, claims_options=claim_options,
claims_params=claims_params, claims_params=claims_params,
) )
@ -721,7 +715,8 @@ class OidcProvider:
logger.debug("Decoded id_token JWT %r; validating", claims) logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(leeway=120) # allows 2 min of clock skew claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)
return claims
async def handle_redirect_request( async def handle_redirect_request(
self, self,
@ -837,8 +832,22 @@ class OidcProvider:
logger.debug("Successfully obtained OAuth2 token data: %r", token) logger.debug("Successfully obtained OAuth2 token data: %r", token)
# Now that we have a token, get the userinfo, either by decoding the # If there is an id_token, it should be validated, regardless of the
# `id_token` or by fetching the `userinfo_endpoint`. # userinfo endpoint is used or not.
if token.get("id_token") is not None:
try:
id_token = await self._parse_id_token(token, nonce=session_data.nonce)
sid = id_token.get("sid")
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
else:
id_token = None
sid = None
# Now that we have a token, get the userinfo either from the `id_token`
# claims or by fetching the `userinfo_endpoint`.
if self._uses_userinfo: if self._uses_userinfo:
try: try:
userinfo = await self._fetch_userinfo(token) userinfo = await self._fetch_userinfo(token)
@ -846,13 +855,14 @@ class OidcProvider:
logger.exception("Could not fetch userinfo") logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e)) self._sso_handler.render_error(request, "fetch_error", str(e))
return return
elif id_token is not None:
userinfo = UserInfo(id_token)
else: else:
try: logger.error("Missing id_token in token response")
userinfo = await self._parse_id_token(token, nonce=session_data.nonce) self._sso_handler.render_error(
except Exception as e: request, "invalid_token", "Missing id_token in token response"
logger.exception("Invalid id_token") )
self._sso_handler.render_error(request, "invalid_token", str(e)) return
return
# first check if we're doing a UIA # first check if we're doing a UIA
if session_data.ui_auth_session_id: if session_data.ui_auth_session_id:
@ -884,7 +894,7 @@ class OidcProvider:
# Call the mapper to register/login the user # Call the mapper to register/login the user
try: try:
await self._complete_oidc_login( await self._complete_oidc_login(
userinfo, token, request, session_data.client_redirect_url userinfo, token, request, session_data.client_redirect_url, sid
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
@ -896,6 +906,7 @@ class OidcProvider:
token: Token, token: Token,
request: SynapseRequest, request: SynapseRequest,
client_redirect_url: str, client_redirect_url: str,
sid: Optional[str],
) -> None: ) -> None:
"""Given a UserInfo response, complete the login flow """Given a UserInfo response, complete the login flow
@ -1008,6 +1019,7 @@ class OidcProvider:
oidc_response_to_user_attributes, oidc_response_to_user_attributes,
grandfather_existing_users, grandfather_existing_users,
extra_attributes, extra_attributes,
auth_provider_session_id=sid,
) )
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:

View file

@ -746,6 +746,7 @@ class RegistrationHandler:
is_appservice_ghost: bool = False, is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> Tuple[str, str, Optional[int], Optional[str]]: ) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token. """Register a device for a user and generate an access token.
@ -756,9 +757,9 @@ class RegistrationHandler:
device_id: The device ID to check, or None to generate a new one. device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device. initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the auth_provider_id: The SSO IdP the user used, if any.
prometheus metrics).
should_issue_refresh_token: Whether it should also issue a refresh token should_issue_refresh_token: Whether it should also issue a refresh token
auth_provider_session_id: The session ID received during login from the SSO IdP.
Returns: Returns:
Tuple of device ID, access token, access token expiration time and refresh token Tuple of device ID, access token, access token expiration time and refresh token
""" """
@ -769,6 +770,8 @@ class RegistrationHandler:
is_guest=is_guest, is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost, is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
) )
login_counter.labels( login_counter.labels(
@ -791,6 +794,8 @@ class RegistrationHandler:
is_guest: bool = False, is_guest: bool = False,
is_appservice_ghost: bool = False, is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> LoginDict: ) -> LoginDict:
"""Helper for register_device """Helper for register_device
@ -822,7 +827,11 @@ class RegistrationHandler:
refresh_token_id = None refresh_token_id = None
registered_device_id = await self.device_handler.check_device_registered( registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id,
device_id,
initial_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
) )
if is_guest: if is_guest:
assert access_token_expiry is None assert access_token_expiry is None

View file

@ -365,6 +365,7 @@ class SsoHandler:
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None, extra_login_attributes: Optional[JsonDict] = None,
auth_provider_session_id: Optional[str] = None,
) -> None: ) -> None:
""" """
Given an SSO ID, retrieve the user ID for it and possibly register the user. Given an SSO ID, retrieve the user ID for it and possibly register the user.
@ -415,6 +416,8 @@ class SsoHandler:
extra_login_attributes: An optional dictionary of extra extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response. attributes to be provided to the client in the login response.
auth_provider_session_id: An optional session ID from the IdP.
Raises: Raises:
MappingException if there was a problem mapping the response to a user. MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user RedirectException: if the mapping provider needs to redirect the user
@ -490,6 +493,7 @@ class SsoHandler:
client_redirect_url, client_redirect_url,
extra_login_attributes, extra_login_attributes,
new_user=new_user, new_user=new_user,
auth_provider_session_id=auth_provider_session_id,
) )
async def _call_attribute_mapper( async def _call_attribute_mapper(

View file

@ -626,6 +626,7 @@ class ModuleApi:
user_id: str, user_id: str,
duration_in_ms: int = (2 * 60 * 1000), duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: str = "", auth_provider_id: str = "",
auth_provider_session_id: Optional[str] = None,
) -> str: ) -> str:
"""Generate a login token suitable for m.login.token authentication """Generate a login token suitable for m.login.token authentication
@ -643,6 +644,7 @@ class ModuleApi:
return self._hs.get_macaroon_generator().generate_short_term_login_token( return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id, user_id,
auth_provider_id, auth_provider_id,
auth_provider_session_id,
duration_in_ms, duration_in_ms,
) )

View file

@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest, is_guest,
is_appservice_ghost, is_appservice_ghost,
should_issue_refresh_token, should_issue_refresh_token,
auth_provider_id,
auth_provider_session_id,
): ):
""" """
Args: Args:
@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"is_guest": is_guest, "is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost, "is_appservice_ghost": is_appservice_ghost,
"should_issue_refresh_token": should_issue_refresh_token, "should_issue_refresh_token": should_issue_refresh_token,
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
} }
async def _handle_request(self, request, user_id): async def _handle_request(self, request, user_id):
@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"] is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"] is_appservice_ghost = content["is_appservice_ghost"]
should_issue_refresh_token = content["should_issue_refresh_token"] should_issue_refresh_token = content["should_issue_refresh_token"]
auth_provider_id = content["auth_provider_id"]
auth_provider_session_id = content["auth_provider_session_id"]
res = await self.registration_handler.register_device_inner( res = await self.registration_handler.register_device_inner(
user_id, user_id,
@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest, is_guest,
is_appservice_ghost=is_appservice_ghost, is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
) )
return 200, res return 200, res

View file

@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet):
ratelimit: bool = True, ratelimit: bool = True,
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> LoginResponse: ) -> LoginResponse:
"""Called when we've successfully authed the user and now need to """Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on actually login them in (e.g. create devices). This gets called on
@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False. exist. Defaults to False.
ratelimit: Whether to ratelimit the login request. ratelimit: Whether to ratelimit the login request.
auth_provider_id: The SSO IdP the user used, if any (just used for the auth_provider_id: The SSO IdP the user used, if any.
prometheus metrics).
should_issue_refresh_token: True if this login should issue should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token. a refresh token alongside the access token.
auth_provider_session_id: The session ID got during login from the SSO IdP.
Returns: Returns:
result: Dictionary of account information after successful login. result: Dictionary of account information after successful login.
@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet):
initial_display_name, initial_display_name,
auth_provider_id=auth_provider_id, auth_provider_id=auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=auth_provider_session_id,
) )
result = LoginResponse( result = LoginResponse(
@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet):
self.auth_handler._sso_login_callback, self.auth_handler._sso_login_callback,
auth_provider_id=res.auth_provider_id, auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
) )
async def _do_jwt_login( async def _do_jwt_login(

View file

@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices} return {d["device_id"]: d for d in devices}
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> List[Dict[str, Any]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
auth_provider_id: The SSO IdP ID as defined in the server config
auth_provider_session_id: The session ID within the IdP
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
return await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
)
@trace @trace
async def get_device_updates_by_remote( async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int self, destination: str, from_stream_id: int, limit: int
@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
async def store_device( async def store_device(
self, user_id: str, device_id: str, initial_device_display_name: Optional[str] self,
user_id: str,
device_id: str,
initial_device_display_name: Optional[str],
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> bool: ) -> bool:
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: id of device device_id: id of device
initial_device_display_name: initial displayname of the device. initial_device_display_name: initial displayname of the device.
Ignored if device exists. Ignored if device exists.
auth_provider_id: The SSO IdP the user used, if any.
auth_provider_session_id: The session ID (sid) got from a OIDC login.
Returns: Returns:
Whether the device was inserted or an existing device existed with that ID. Whether the device was inserted or an existing device existed with that ID.
@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if hidden: if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
if auth_provider_id and auth_provider_session_id:
await self.db_pool.simple_insert(
"device_auth_providers",
values={
"user_id": user_id,
"device_id": device_id,
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
desc="store_device_auth_provider",
)
self.device_id_exists_cache.set(key, True) self.device_id_exists_cache.set(key, True)
return inserted return inserted
except StoreError: except StoreError:
@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
) )
self.db_pool.simple_delete_many_txn(
txn,
table="device_auth_providers",
column="device_id",
values=device_ids,
keyvalues={"user_id": user_id},
)
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids: for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id)) self.device_id_exists_cache.invalidate((user_id, device_id))

View file

@ -0,0 +1,27 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Track the auth provider used by each login as well as the session ID
CREATE TABLE device_auth_providers (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
auth_provider_id TEXT NOT NULL,
auth_provider_session_id TEXT NOT NULL
);
CREATE INDEX device_auth_providers_devices
ON device_auth_providers (user_id, device_id);
CREATE INDEX device_auth_providers_sessions
ON device_auth_providers (auth_provider_id, auth_provider_session_id);

View file

@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", 5000 self.user1, "", duration_in_ms=5000
) )
res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual(self.user1, res.user_id) self.assertEqual(self.user1, res.user_id)
@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", 5000 self.user1, "", duration_in_ms=5000
) )
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
def _get_macaroon(self): def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", 5000 self.user1, "", duration_in_ms=5000
) )
return pymacaroons.Macaroon.deserialize(token) return pymacaroons.Macaroon.deserialize(token)

View file

@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True "@test_user:test",
"cas",
request,
"redirect_uri",
None,
new_user=True,
auth_provider_session_id=None,
) )
def test_map_cas_user_to_existing_user(self): def test_map_cas_user_to_existing_user(self):
@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False "@test_user:test",
"cas",
request,
"redirect_uri",
None,
new_user=False,
auth_provider_session_id=None,
) )
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
) )
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False "@test_user:test",
"cas",
request,
"redirect_uri",
None,
new_user=False,
auth_provider_session_id=None,
) )
def test_map_cas_user_to_invalid_localpart(self): def test_map_cas_user_to_invalid_localpart(self):
@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True "@f=c3=b6=c3=b6:test",
"cas",
request,
"redirect_uri",
None,
new_user=True,
auth_provider_session_id=None,
) )
@override_config( @override_config(
@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True "@test_user:test",
"cas",
request,
"redirect_uri",
None,
new_user=True,
auth_provider_session_id=None,
) )

View file

@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
with patch.object(self.provider, "load_metadata", patched_load_metadata): with patch.object(self.provider, "load_metadata", patched_load_metadata):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError) self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
self.provider._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
jwks = self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self): def test_validate_config(self):
"""Provider metadatas are extensively validated.""" """Provider metadatas are extensively validated."""
@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, "oidc", request, client_redirect_url, None, new_user=True expected_user_id,
"oidc",
request,
client_redirect_url,
None,
new_user=True,
auth_provider_session_id=None,
) )
self.provider._exchange_code.assert_called_once_with(code) self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.provider._fetch_userinfo.reset_mock() self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching # With userinfo fetching
self.provider._scopes = [] # do not ask the "openid" scope self.provider._user_profile_method = "userinfo_endpoint"
token = {
"type": "bearer",
"access_token": "access_token",
}
self.provider._exchange_code = simple_async_mock(return_value=token)
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, "oidc", request, client_redirect_url, None, new_user=False expected_user_id,
"oidc",
request,
client_redirect_url,
None,
new_user=False,
auth_provider_session_id=None,
) )
self.provider._exchange_code.assert_called_once_with(code) self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called() self.provider._parse_id_token.assert_not_called()
self.provider._fetch_userinfo.assert_called_once_with(token) self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called() self.render_error.assert_not_called()
# With an ID token, userinfo fetching and sid in the ID token
self.provider._user_profile_method = "userinfo_endpoint"
token = {
"type": "bearer",
"access_token": "access_token",
"id_token": "id_token",
}
id_token = {
"sid": "abcdefgh",
}
self.provider._parse_id_token = simple_async_mock(return_value=id_token)
self.provider._exchange_code = simple_async_mock(return_value=token)
auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id,
"oidc",
request,
client_redirect_url,
None,
new_user=False,
auth_provider_session_id=id_token["sid"],
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
# Handle userinfo fetching error # Handle userinfo fetching error
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url, client_redirect_url,
{"phone": "1234567"}, {"phone": "1234567"},
new_user=True, new_user=True,
auth_provider_session_id=None,
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "oidc", ANY, ANY, None, new_user=True "@test_user:test",
"oidc",
ANY,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", "oidc", ANY, ANY, None, new_user=True "@test_user_2:test",
"oidc",
ANY,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), "oidc", ANY, ANY, None, new_user=False user.to_string(),
"oidc",
ANY,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), "oidc", ANY, ANY, None, new_user=False user.to_string(),
"oidc",
ANY,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), "oidc", ANY, ANY, None, new_user=False user.to_string(),
"oidc",
ANY,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False "@TEST_USER_2:test",
"oidc",
ANY,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead. # test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", "oidc", ANY, ANY, None, new_user=True "@test_user1:test",
"oidc",
ANY,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@tester:test", "oidc", ANY, ANY, None, new_user=True "@tester:test",
"oidc",
ANY,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
) )
@override_config( @override_config(
@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@tester:test", "oidc", ANY, ANY, None, new_user=True "@tester:test",
"oidc",
ANY,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
) )
@override_config( @override_config(
@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler() handler = hs.get_oidc_handler()
provider = handler._providers["oidc"] provider = handler._providers["oidc"]
provider._exchange_code = simple_async_mock(return_value={}) provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._parse_id_token = simple_async_mock(return_value=userinfo)
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo)

View file

@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True "@test_user:test",
"saml",
request,
"redirect_uri",
None,
new_user=True,
auth_provider_session_id=None,
) )
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "saml", request, "", None, new_user=False "@test_user:test",
"saml",
request,
"",
None,
new_user=False,
auth_provider_session_id=None,
) )
# Subsequent calls should map to the same mxid. # Subsequent calls should map to the same mxid.
@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "") self.handler._handle_authn_response(request, saml_response, "")
) )
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "saml", request, "", None, new_user=False "@test_user:test",
"saml",
request,
"",
None,
new_user=False,
auth_provider_session_id=None,
) )
def test_map_saml_response_to_invalid_localpart(self): def test_map_saml_response_to_invalid_localpart(self):
@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead. # test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", "saml", request, "", None, new_user=True "@test_user1:test",
"saml",
request,
"",
None,
new_user=True,
auth_provider_session_id=None,
) )
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected # check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True "@test_user:test",
"saml",
request,
"redirect_uri",
None,
new_user=True,
auth_provider_session_id=None,
) )