From 3ab9e6d524960784f6f108e57ac86a921912ea84 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 21 Mar 2024 18:49:44 +0100 Subject: [PATCH] OIDC: try to JWT decode userinfo response if JSON parsing failed (#16972) --- changelog.d/16972.feature | 1 + synapse/handlers/oidc.py | 32 ++++++++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) create mode 100644 changelog.d/16972.feature diff --git a/changelog.d/16972.feature b/changelog.d/16972.feature new file mode 100644 index 000000000..0f28cbbcd --- /dev/null +++ b/changelog.d/16972.feature @@ -0,0 +1 @@ +OIDC: try to JWT decode userinfo response if JSON parsing failed. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index ba67cc476..ab28dc800 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -829,14 +829,38 @@ class OidcProvider: logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() - resp = await self._http_client.get_json( + resp = await self._http_client.request( + "GET", metadata["userinfo_endpoint"], - headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, + headers=Headers( + {"Authorization": ["Bearer {}".format(token["access_token"])]} + ), ) - logger.debug("Retrieved user info from userinfo endpoint: %r", resp) + body = await readBody(resp) - return UserInfo(resp) + content_type_headers = resp.headers.getRawHeaders("Content-Type") + assert content_type_headers + # We use `startswith` because the header value can contain the `charset` parameter + # even if it is useless, and Twisted doesn't take care of that for us. + if content_type_headers[0].startswith("application/jwt"): + alg_values = metadata.get( + "id_token_signing_alg_values_supported", ["RS256"] + ) + jwt = JsonWebToken(alg_values) + jwk_set = await self.load_jwks() + try: + decoded_resp = jwt.decode(body, key=jwk_set) + except ValueError: + logger.info("Reloading JWKS after decode error") + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + decoded_resp = jwt.decode(body, key=jwk_set) + else: + decoded_resp = json_decoder.decode(body.decode("utf-8")) + + logger.debug("Retrieved user info from userinfo endpoint: %r", decoded_resp) + + return UserInfo(decoded_resp) async def _verify_jwt( self,