mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 07:23:48 +01:00
Support OIDC backchannel logouts (#11414)
If configured an OIDC IdP can log a user's session out of Synapse when they log out of the identity provider. The IdP sends a request directly to Synapse (and must be configured with an endpoint) when a user logs out.
This commit is contained in:
parent
15bdb0da52
commit
cc3a52b33d
13 changed files with 960 additions and 66 deletions
1
changelog.d/11414.feature
Normal file
1
changelog.d/11414.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support back-channel logouts from OpenID Connect providers.
|
|
@ -49,6 +49,13 @@ setting in your configuration file.
|
||||||
See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as
|
See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as
|
||||||
the text below for example configurations for specific providers.
|
the text below for example configurations for specific providers.
|
||||||
|
|
||||||
|
## OIDC Back-Channel Logout
|
||||||
|
|
||||||
|
Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications.
|
||||||
|
|
||||||
|
This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session.
|
||||||
|
This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout`
|
||||||
|
|
||||||
## Sample configs
|
## Sample configs
|
||||||
|
|
||||||
Here are a few configs for providers that should work with Synapse.
|
Here are a few configs for providers that should work with Synapse.
|
||||||
|
@ -123,6 +130,9 @@ oidc_providers:
|
||||||
|
|
||||||
[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
|
[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
|
||||||
|
|
||||||
|
Keycloak supports OIDC Back-Channel Logout, which sends logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak.
|
||||||
|
This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak.
|
||||||
|
|
||||||
Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
|
Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
|
||||||
|
|
||||||
1. Click `Clients` in the sidebar and click `Create`
|
1. Click `Clients` in the sidebar and click `Create`
|
||||||
|
@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
|
||||||
| Client Protocol | `openid-connect` |
|
| Client Protocol | `openid-connect` |
|
||||||
| Access Type | `confidential` |
|
| Access Type | `confidential` |
|
||||||
| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
|
| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
|
||||||
|
| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` |
|
||||||
|
| Backchannel Logout Session Required (optional) | `On` |
|
||||||
|
|
||||||
5. Click `Save`
|
5. Click `Save`
|
||||||
6. On the Credentials tab, update the fields:
|
6. On the Credentials tab, update the fields:
|
||||||
|
@ -167,7 +179,9 @@ oidc_providers:
|
||||||
config:
|
config:
|
||||||
localpart_template: "{{ user.preferred_username }}"
|
localpart_template: "{{ user.preferred_username }}"
|
||||||
display_name_template: "{{ user.name }}"
|
display_name_template: "{{ user.name }}"
|
||||||
|
backchannel_logout_enabled: true # Optional
|
||||||
```
|
```
|
||||||
|
|
||||||
### Auth0
|
### Auth0
|
||||||
|
|
||||||
[Auth0][auth0] is a hosted SaaS IdP solution.
|
[Auth0][auth0] is a hosted SaaS IdP solution.
|
||||||
|
|
|
@ -3021,6 +3021,15 @@ Options for each entry include:
|
||||||
which is set to the claims returned by the UserInfo Endpoint and/or
|
which is set to the claims returned by the UserInfo Endpoint and/or
|
||||||
in the ID Token.
|
in the ID Token.
|
||||||
|
|
||||||
|
* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications.
|
||||||
|
Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`.
|
||||||
|
Defaults to `false`.
|
||||||
|
|
||||||
|
* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the
|
||||||
|
`sub` claim matches the subject claim received during login. This check can be disabled by setting
|
||||||
|
this to `true`. Defaults to `false`.
|
||||||
|
|
||||||
|
You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`.
|
||||||
|
|
||||||
It is possible to configure Synapse to only allow logins if certain attributes
|
It is possible to configure Synapse to only allow logins if certain attributes
|
||||||
match particular values in the OIDC userinfo. The requirements can be listed under
|
match particular values in the OIDC userinfo. The requirements can be listed under
|
||||||
|
|
|
@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
||||||
"userinfo_endpoint": {"type": "string"},
|
"userinfo_endpoint": {"type": "string"},
|
||||||
"jwks_uri": {"type": "string"},
|
"jwks_uri": {"type": "string"},
|
||||||
"skip_verification": {"type": "boolean"},
|
"skip_verification": {"type": "boolean"},
|
||||||
|
"backchannel_logout_enabled": {"type": "boolean"},
|
||||||
|
"backchannel_logout_ignore_sub": {"type": "boolean"},
|
||||||
"user_profile_method": {
|
"user_profile_method": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["auto", "userinfo_endpoint"],
|
"enum": ["auto", "userinfo_endpoint"],
|
||||||
|
@ -292,6 +294,10 @@ def _parse_oidc_config_dict(
|
||||||
token_endpoint=oidc_config.get("token_endpoint"),
|
token_endpoint=oidc_config.get("token_endpoint"),
|
||||||
userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
|
userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
|
||||||
jwks_uri=oidc_config.get("jwks_uri"),
|
jwks_uri=oidc_config.get("jwks_uri"),
|
||||||
|
backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False),
|
||||||
|
backchannel_logout_ignore_sub=oidc_config.get(
|
||||||
|
"backchannel_logout_ignore_sub", False
|
||||||
|
),
|
||||||
skip_verification=oidc_config.get("skip_verification", False),
|
skip_verification=oidc_config.get("skip_verification", False),
|
||||||
user_profile_method=oidc_config.get("user_profile_method", "auto"),
|
user_profile_method=oidc_config.get("user_profile_method", "auto"),
|
||||||
allow_existing_users=oidc_config.get("allow_existing_users", False),
|
allow_existing_users=oidc_config.get("allow_existing_users", False),
|
||||||
|
@ -368,6 +374,12 @@ class OidcProviderConfig:
|
||||||
# "openid" scope is used.
|
# "openid" scope is used.
|
||||||
jwks_uri: Optional[str]
|
jwks_uri: Optional[str]
|
||||||
|
|
||||||
|
# Whether Synapse should react to backchannel logouts
|
||||||
|
backchannel_logout_enabled: bool
|
||||||
|
|
||||||
|
# Whether Synapse should ignore the `sub` claim in backchannel logouts or not.
|
||||||
|
backchannel_logout_ignore_sub: bool
|
||||||
|
|
||||||
# Whether to skip metadata verification
|
# Whether to skip metadata verification
|
||||||
skip_verification: bool
|
skip_verification: bool
|
||||||
|
|
||||||
|
|
|
@ -12,14 +12,28 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import binascii
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
from urllib.parse import urlencode, urlparse
|
from urllib.parse import urlencode, urlparse
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
import unpaddedbase64
|
||||||
from authlib.common.security import generate_token
|
from authlib.common.security import generate_token
|
||||||
from authlib.jose import JsonWebToken, jwt
|
from authlib.jose import JsonWebToken, JWTClaims
|
||||||
|
from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
|
||||||
from authlib.oauth2.auth import ClientAuth
|
from authlib.oauth2.auth import ClientAuth
|
||||||
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
||||||
from authlib.oidc.core import CodeIDToken, UserInfo
|
from authlib.oidc.core import CodeIDToken, UserInfo
|
||||||
|
@ -35,9 +49,12 @@ from typing_extensions import TypedDict
|
||||||
from twisted.web.client import readBody
|
from twisted.web.client import readBody
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
|
from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
|
||||||
from synapse.handlers.sso import MappingException, UserAttributes
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
|
from synapse.http.server import finish_request
|
||||||
|
from synapse.http.servlet import parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
|
@ -88,6 +105,8 @@ class Token(TypedDict):
|
||||||
#: there is no real point of doing this in our case.
|
#: there is no real point of doing this in our case.
|
||||||
JWK = Dict[str, str]
|
JWK = Dict[str, str]
|
||||||
|
|
||||||
|
C = TypeVar("C")
|
||||||
|
|
||||||
|
|
||||||
#: A JWK Set, as per RFC7517 sec 5.
|
#: A JWK Set, as per RFC7517 sec 5.
|
||||||
class JWKS(TypedDict):
|
class JWKS(TypedDict):
|
||||||
|
@ -247,6 +266,80 @@ class OidcHandler:
|
||||||
|
|
||||||
await oidc_provider.handle_oidc_callback(request, session_data, code)
|
await oidc_provider.handle_oidc_callback(request, session_data, code)
|
||||||
|
|
||||||
|
async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
|
||||||
|
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
|
||||||
|
|
||||||
|
This extracts the logout_token from the request and tries to figure out
|
||||||
|
which OpenID Provider it is comming from. This works by matching the iss claim
|
||||||
|
with the issuer and the aud claim with the client_id.
|
||||||
|
|
||||||
|
Since at this point we don't know who signed the JWT, we can't just
|
||||||
|
decode it using authlib since it will always verifies the signature. We
|
||||||
|
have to decode it manually without validating the signature. The actual JWT
|
||||||
|
verification is done in the `OidcProvider.handler_backchannel_logout` method,
|
||||||
|
once we figured out which provider sent the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: the incoming request from the browser.
|
||||||
|
"""
|
||||||
|
logout_token = parse_string(request, "logout_token")
|
||||||
|
if logout_token is None:
|
||||||
|
raise SynapseError(400, "Missing logout_token in request")
|
||||||
|
|
||||||
|
# A JWT looks like this:
|
||||||
|
# header.payload.signature
|
||||||
|
# where all parts are encoded with urlsafe base64.
|
||||||
|
# The aud and iss claims we care about are in the payload part, which
|
||||||
|
# is a JSON object.
|
||||||
|
try:
|
||||||
|
# By destructuring the list after splitting, we ensure that we have
|
||||||
|
# exactly 3 segments
|
||||||
|
_, payload, _ = logout_token.split(".")
|
||||||
|
except ValueError:
|
||||||
|
raise SynapseError(400, "Invalid logout_token in request")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload_bytes = unpaddedbase64.decode_base64(payload)
|
||||||
|
claims = json_decoder.decode(payload_bytes.decode("utf-8"))
|
||||||
|
except (json.JSONDecodeError, binascii.Error, UnicodeError):
|
||||||
|
raise SynapseError(400, "Invalid logout_token payload in request")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Let's extract the iss and aud claims
|
||||||
|
iss = claims["iss"]
|
||||||
|
aud = claims["aud"]
|
||||||
|
# The aud claim can be either a string or a list of string. Here we
|
||||||
|
# normalize it as a list of strings.
|
||||||
|
if isinstance(aud, str):
|
||||||
|
aud = [aud]
|
||||||
|
|
||||||
|
# Check that we have the right types for the aud and the iss claims
|
||||||
|
if not isinstance(iss, str) or not isinstance(aud, list):
|
||||||
|
raise TypeError()
|
||||||
|
for a in aud:
|
||||||
|
if not isinstance(a, str):
|
||||||
|
raise TypeError()
|
||||||
|
|
||||||
|
# At this point we properly checked both claims types
|
||||||
|
issuer: str = iss
|
||||||
|
audience: List[str] = aud
|
||||||
|
except (TypeError, KeyError):
|
||||||
|
raise SynapseError(400, "Invalid issuer/audience in logout_token")
|
||||||
|
|
||||||
|
# Now that we know the audience and the issuer, we can figure out from
|
||||||
|
# what provider it is coming from
|
||||||
|
oidc_provider: Optional[OidcProvider] = None
|
||||||
|
for provider in self._providers.values():
|
||||||
|
if provider.issuer == issuer and provider.client_id in audience:
|
||||||
|
oidc_provider = provider
|
||||||
|
break
|
||||||
|
|
||||||
|
if oidc_provider is None:
|
||||||
|
raise SynapseError(400, "Could not find the OP that issued this event")
|
||||||
|
|
||||||
|
# Ask the provider to handle the logout request.
|
||||||
|
await oidc_provider.handle_backchannel_logout(request, logout_token)
|
||||||
|
|
||||||
|
|
||||||
class OidcError(Exception):
|
class OidcError(Exception):
|
||||||
"""Used to catch errors when calling the token_endpoint"""
|
"""Used to catch errors when calling the token_endpoint"""
|
||||||
|
@ -342,6 +435,7 @@ class OidcProvider:
|
||||||
self.idp_brand = provider.idp_brand
|
self.idp_brand = provider.idp_brand
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
self._device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
self._sso_handler.register_identity_provider(self)
|
self._sso_handler.register_identity_provider(self)
|
||||||
|
|
||||||
|
@ -400,6 +494,41 @@ class OidcProvider:
|
||||||
# If we're not using userinfo, we need a valid jwks to validate the ID token
|
# If we're not using userinfo, we need a valid jwks to validate the ID token
|
||||||
m.validate_jwks_uri()
|
m.validate_jwks_uri()
|
||||||
|
|
||||||
|
if self._config.backchannel_logout_enabled:
|
||||||
|
if not m.get("backchannel_logout_supported", False):
|
||||||
|
logger.warning(
|
||||||
|
"OIDC Back-Channel Logout is enabled for issuer %r"
|
||||||
|
"but it does not advertise support for it",
|
||||||
|
self.issuer,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif not m.get("backchannel_logout_session_supported", False):
|
||||||
|
logger.warning(
|
||||||
|
"OIDC Back-Channel Logout is enabled and supported "
|
||||||
|
"by issuer %r but it might not send a session ID with "
|
||||||
|
"logout tokens, which is required for the logouts to work",
|
||||||
|
self.issuer,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._config.backchannel_logout_ignore_sub:
|
||||||
|
# If OIDC backchannel logouts are enabled, the provider mapping provider
|
||||||
|
# should use the `sub` claim. We verify that by mapping a dumb user and
|
||||||
|
# see if we get back the sub claim
|
||||||
|
user = UserInfo({"sub": "thisisasubject"})
|
||||||
|
try:
|
||||||
|
subject = self._user_mapping_provider.get_remote_user_id(user)
|
||||||
|
if subject != user["sub"]:
|
||||||
|
raise ValueError("Unexpected subject")
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
|
||||||
|
"but it looks like the configured `user_mapping_provider` "
|
||||||
|
"does not use the `sub` claim as subject. If it is the case, "
|
||||||
|
"and you want Synapse to ignore the `sub` claim in OIDC "
|
||||||
|
"Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
|
||||||
|
"to `true` in the issuer config."
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _uses_userinfo(self) -> bool:
|
def _uses_userinfo(self) -> bool:
|
||||||
"""Returns True if the ``userinfo_endpoint`` should be used.
|
"""Returns True if the ``userinfo_endpoint`` should be used.
|
||||||
|
@ -415,6 +544,16 @@ class OidcProvider:
|
||||||
or self._user_profile_method == "userinfo_endpoint"
|
or self._user_profile_method == "userinfo_endpoint"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def issuer(self) -> str:
|
||||||
|
"""The issuer identifying this provider."""
|
||||||
|
return self._config.issuer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_id(self) -> str:
|
||||||
|
"""The client_id used when interacting with this provider."""
|
||||||
|
return self._config.client_id
|
||||||
|
|
||||||
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
|
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
|
||||||
"""Return the provider metadata.
|
"""Return the provider metadata.
|
||||||
|
|
||||||
|
@ -662,6 +801,59 @@ class OidcProvider:
|
||||||
|
|
||||||
return UserInfo(resp)
|
return UserInfo(resp)
|
||||||
|
|
||||||
|
async def _verify_jwt(
|
||||||
|
self,
|
||||||
|
alg_values: List[str],
|
||||||
|
token: str,
|
||||||
|
claims_cls: Type[C],
|
||||||
|
claims_options: Optional[dict] = None,
|
||||||
|
claims_params: Optional[dict] = None,
|
||||||
|
) -> C:
|
||||||
|
"""Decode and validate a JWT, re-fetching the JWKS as needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alg_values: list of `alg` values allowed when verifying the JWT.
|
||||||
|
token: the JWT.
|
||||||
|
claims_cls: the JWTClaims class to use to validate the claims.
|
||||||
|
claims_options: dict of options passed to the `claims_cls` constructor.
|
||||||
|
claims_params: dict of params passed to the `claims_cls` constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decoded claims in the JWT.
|
||||||
|
"""
|
||||||
|
jwt = JsonWebToken(alg_values)
|
||||||
|
|
||||||
|
logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
|
||||||
|
|
||||||
|
# Try to decode the keys in cache first, then retry by forcing the keys
|
||||||
|
# to be reloaded
|
||||||
|
jwk_set = await self.load_jwks()
|
||||||
|
try:
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
key=jwk_set,
|
||||||
|
claims_cls=claims_cls,
|
||||||
|
claims_options=claims_options,
|
||||||
|
claims_params=claims_params,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
logger.info("Reloading JWKS after decode error")
|
||||||
|
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
key=jwk_set,
|
||||||
|
claims_cls=claims_cls,
|
||||||
|
claims_options=claims_options,
|
||||||
|
claims_params=claims_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
|
||||||
|
|
||||||
|
claims.validate(
|
||||||
|
now=self._clock.time(), leeway=120
|
||||||
|
) # allows 2 min of clock skew
|
||||||
|
return claims
|
||||||
|
|
||||||
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
|
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
|
||||||
"""Return an instance of UserInfo from token's ``id_token``.
|
"""Return an instance of UserInfo from token's ``id_token``.
|
||||||
|
|
||||||
|
@ -675,13 +867,13 @@ class OidcProvider:
|
||||||
The decoded claims in the ID token.
|
The decoded claims in the ID token.
|
||||||
"""
|
"""
|
||||||
id_token = token.get("id_token")
|
id_token = token.get("id_token")
|
||||||
logger.debug("Attempting to decode JWT id_token %r", id_token)
|
|
||||||
|
|
||||||
# That has been theoritically been checked by the caller, so even though
|
# That has been theoritically been checked by the caller, so even though
|
||||||
# assertion are not enabled in production, it is mainly here to appease mypy
|
# assertion are not enabled in production, it is mainly here to appease mypy
|
||||||
assert id_token is not None
|
assert id_token is not None
|
||||||
|
|
||||||
metadata = await self.load_metadata()
|
metadata = await self.load_metadata()
|
||||||
|
|
||||||
claims_params = {
|
claims_params = {
|
||||||
"nonce": nonce,
|
"nonce": nonce,
|
||||||
"client_id": self._client_auth.client_id,
|
"client_id": self._client_auth.client_id,
|
||||||
|
@ -691,38 +883,17 @@ class OidcProvider:
|
||||||
# in the `id_token` that we can check against.
|
# in the `id_token` that we can check against.
|
||||||
claims_params["access_token"] = token["access_token"]
|
claims_params["access_token"] = token["access_token"]
|
||||||
|
|
||||||
|
claims_options = {"iss": {"values": [metadata["issuer"]]}}
|
||||||
|
|
||||||
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
||||||
jwt = JsonWebToken(alg_values)
|
|
||||||
|
|
||||||
claim_options = {"iss": {"values": [metadata["issuer"]]}}
|
claims = await self._verify_jwt(
|
||||||
|
alg_values=alg_values,
|
||||||
# Try to decode the keys in cache first, then retry by forcing the keys
|
token=id_token,
|
||||||
# to be reloaded
|
claims_cls=CodeIDToken,
|
||||||
jwk_set = await self.load_jwks()
|
claims_options=claims_options,
|
||||||
try:
|
claims_params=claims_params,
|
||||||
claims = jwt.decode(
|
)
|
||||||
id_token,
|
|
||||||
key=jwk_set,
|
|
||||||
claims_cls=CodeIDToken,
|
|
||||||
claims_options=claim_options,
|
|
||||||
claims_params=claims_params,
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
logger.info("Reloading JWKS after decode error")
|
|
||||||
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
|
|
||||||
claims = jwt.decode(
|
|
||||||
id_token,
|
|
||||||
key=jwk_set,
|
|
||||||
claims_cls=CodeIDToken,
|
|
||||||
claims_options=claim_options,
|
|
||||||
claims_params=claims_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Decoded id_token JWT %r; validating", claims)
|
|
||||||
|
|
||||||
claims.validate(
|
|
||||||
now=self._clock.time(), leeway=120
|
|
||||||
) # allows 2 min of clock skew
|
|
||||||
|
|
||||||
return claims
|
return claims
|
||||||
|
|
||||||
|
@ -1043,6 +1214,146 @@ class OidcProvider:
|
||||||
# to be strings.
|
# to be strings.
|
||||||
return str(remote_user_id)
|
return str(remote_user_id)
|
||||||
|
|
||||||
|
async def handle_backchannel_logout(
|
||||||
|
self, request: SynapseRequest, logout_token: str
|
||||||
|
) -> None:
|
||||||
|
"""Handle an incoming request to /_synapse/client/oidc/backchannel_logout
|
||||||
|
|
||||||
|
The OIDC Provider posts a logout token to this endpoint when a user
|
||||||
|
session ends. That token is a JWT signed with the same keys as
|
||||||
|
ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
|
||||||
|
validate the JWT and figure out what session to end.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The request to respond to
|
||||||
|
logout_token: The logout token (a JWT) extracted from the request body
|
||||||
|
"""
|
||||||
|
# Back-Channel Logout can be disabled in the config, hence this check.
|
||||||
|
# This is not that important for now since Synapse is registered
|
||||||
|
# manually to the OP, so not specifying the backchannel-logout URI is
|
||||||
|
# as effective than disabling it here. It might make more sense if we
|
||||||
|
# support dynamic registration in Synapse at some point.
|
||||||
|
if not self._config.backchannel_logout_enabled:
|
||||||
|
logger.warning(
|
||||||
|
f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: this responds with a 400 status code, which is what the OIDC
|
||||||
|
# Back-Channel Logout spec expects, but spec also suggests answering with
|
||||||
|
# a JSON object, with the `error` and `error_description` fields set, which
|
||||||
|
# we are not doing here.
|
||||||
|
# See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
|
||||||
|
raise SynapseError(
|
||||||
|
400, "OpenID Connect Back-Channel Logout is disabled for this provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = await self.load_metadata()
|
||||||
|
|
||||||
|
# As per OIDC Back-Channel Logout 1.0 sec. 2.4:
|
||||||
|
# A Logout Token MUST be signed and MAY also be encrypted. The same
|
||||||
|
# keys are used to sign and encrypt Logout Tokens as are used for ID
|
||||||
|
# Tokens. If the Logout Token is encrypted, it SHOULD replicate the
|
||||||
|
# iss (issuer) claim in the JWT Header Parameters, as specified in
|
||||||
|
# Section 5.3 of [JWT].
|
||||||
|
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
||||||
|
|
||||||
|
# As per sec. 2.6:
|
||||||
|
# 3. Validate the iss, aud, and iat Claims in the same way they are
|
||||||
|
# validated in ID Tokens.
|
||||||
|
# Which means the audience should contain Synapse's client_id and the
|
||||||
|
# issuer should be the IdP issuer
|
||||||
|
claims_options = {
|
||||||
|
"iss": {"values": [metadata["issuer"]]},
|
||||||
|
"aud": {"values": [self.client_id]},
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
claims = await self._verify_jwt(
|
||||||
|
alg_values=alg_values,
|
||||||
|
token=logout_token,
|
||||||
|
claims_cls=LogoutToken,
|
||||||
|
claims_options=claims_options,
|
||||||
|
)
|
||||||
|
except JoseError:
|
||||||
|
logger.exception("Invalid logout_token")
|
||||||
|
raise SynapseError(400, "Invalid logout_token")
|
||||||
|
|
||||||
|
# As per sec. 2.6:
|
||||||
|
# 4. Verify that the Logout Token contains a sub Claim, a sid Claim,
|
||||||
|
# or both.
|
||||||
|
# 5. Verify that the Logout Token contains an events Claim whose
|
||||||
|
# value is JSON object containing the member name
|
||||||
|
# http://schemas.openid.net/event/backchannel-logout.
|
||||||
|
# 6. Verify that the Logout Token does not contain a nonce Claim.
|
||||||
|
# This is all verified by the LogoutToken claims class, so at this
|
||||||
|
# point the `sid` claim exists and is a string.
|
||||||
|
sid: str = claims.get("sid")
|
||||||
|
|
||||||
|
# If the `sub` claim was included in the logout token, we check that it matches
|
||||||
|
# that it matches the right user. We can have cases where the `sub` claim is not
|
||||||
|
# the ID saved in database, so we let admins disable this check in config.
|
||||||
|
sub: Optional[str] = claims.get("sub")
|
||||||
|
expected_user_id: Optional[str] = None
|
||||||
|
if sub is not None and not self._config.backchannel_logout_ignore_sub:
|
||||||
|
expected_user_id = await self._store.get_user_by_external_id(
|
||||||
|
self.idp_id, sub
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalidate any running user-mapping sessions, in-flight login tokens and
|
||||||
|
# active devices
|
||||||
|
await self._sso_handler.revoke_sessions_for_provider_session_id(
|
||||||
|
auth_provider_id=self.idp_id,
|
||||||
|
auth_provider_session_id=sid,
|
||||||
|
expected_user_id=expected_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
request.setResponseCode(200)
|
||||||
|
request.setHeader(b"Cache-Control", b"no-cache, no-store")
|
||||||
|
request.setHeader(b"Pragma", b"no-cache")
|
||||||
|
finish_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
class LogoutToken(JWTClaims):
|
||||||
|
"""
|
||||||
|
Holds and verify claims of a logout token, as per
|
||||||
|
https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
|
||||||
|
"""
|
||||||
|
|
||||||
|
REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
|
||||||
|
|
||||||
|
def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
|
||||||
|
"""Validate everything in claims payload."""
|
||||||
|
super().validate(now, leeway)
|
||||||
|
self.validate_sid()
|
||||||
|
self.validate_events()
|
||||||
|
self.validate_nonce()
|
||||||
|
|
||||||
|
def validate_sid(self) -> None:
|
||||||
|
"""Ensure the sid claim is present"""
|
||||||
|
sid = self.get("sid")
|
||||||
|
if not sid:
|
||||||
|
raise MissingClaimError("sid")
|
||||||
|
|
||||||
|
if not isinstance(sid, str):
|
||||||
|
raise InvalidClaimError("sid")
|
||||||
|
|
||||||
|
def validate_nonce(self) -> None:
|
||||||
|
"""Ensure the nonce claim is absent"""
|
||||||
|
if "nonce" in self:
|
||||||
|
raise InvalidClaimError("nonce")
|
||||||
|
|
||||||
|
def validate_events(self) -> None:
|
||||||
|
"""Ensure the events claim is present and with the right value"""
|
||||||
|
events = self.get("events")
|
||||||
|
if not events:
|
||||||
|
raise MissingClaimError("events")
|
||||||
|
|
||||||
|
if not isinstance(events, dict):
|
||||||
|
raise InvalidClaimError("events")
|
||||||
|
|
||||||
|
if "http://schemas.openid.net/event/backchannel-logout" not in events:
|
||||||
|
raise InvalidClaimError("events")
|
||||||
|
|
||||||
|
|
||||||
# number of seconds a newly-generated client secret should be valid for
|
# number of seconds a newly-generated client secret should be valid for
|
||||||
CLIENT_SECRET_VALIDITY_SECONDS = 3600
|
CLIENT_SECRET_VALIDITY_SECONDS = 3600
|
||||||
|
@ -1112,6 +1423,7 @@ class JwtClientSecret:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
|
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
|
||||||
)
|
)
|
||||||
|
jwt = JsonWebToken(header["alg"])
|
||||||
self._cached_secret = jwt.encode(header, payload, self._key.key)
|
self._cached_secret = jwt.encode(header, payload, self._key.key)
|
||||||
self._cached_secret_replacement_time = (
|
self._cached_secret_replacement_time = (
|
||||||
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
|
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
|
||||||
|
@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict):
|
||||||
emails: List[str]
|
emails: List[str]
|
||||||
|
|
||||||
|
|
||||||
C = TypeVar("C")
|
|
||||||
|
|
||||||
|
|
||||||
class OidcMappingProvider(Generic[C]):
|
class OidcMappingProvider(Generic[C]):
|
||||||
"""A mapping provider maps a UserInfo object to user attributes.
|
"""A mapping provider maps a UserInfo object to user attributes.
|
||||||
|
|
||||||
|
|
|
@ -191,6 +191,7 @@ class SsoHandler:
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._device_handler = hs.get_device_handler()
|
||||||
self._error_template = hs.config.sso.sso_error_template
|
self._error_template = hs.config.sso.sso_error_template
|
||||||
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
|
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
|
||||||
self._profile_handler = hs.get_profile_handler()
|
self._profile_handler = hs.get_profile_handler()
|
||||||
|
@ -1026,6 +1027,76 @@ class SsoHandler:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def revoke_sessions_for_provider_session_id(
|
||||||
|
self,
|
||||||
|
auth_provider_id: str,
|
||||||
|
auth_provider_session_id: str,
|
||||||
|
expected_user_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Revoke any devices and in-flight logins tied to a provider session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
|
"oidc" or "saml".
|
||||||
|
auth_provider_session_id: The session ID from the provider to logout
|
||||||
|
expected_user_id: The user we're expecting to logout. If set, it will ignore
|
||||||
|
sessions belonging to other users and log an error.
|
||||||
|
"""
|
||||||
|
# Invalidate any running user-mapping sessions
|
||||||
|
to_delete = []
|
||||||
|
for session_id, session in self._username_mapping_sessions.items():
|
||||||
|
if (
|
||||||
|
session.auth_provider_id == auth_provider_id
|
||||||
|
and session.auth_provider_session_id == auth_provider_session_id
|
||||||
|
):
|
||||||
|
to_delete.append(session_id)
|
||||||
|
|
||||||
|
for session_id in to_delete:
|
||||||
|
logger.info("Revoking mapping session %s", session_id)
|
||||||
|
del self._username_mapping_sessions[session_id]
|
||||||
|
|
||||||
|
# Invalidate any in-flight login tokens
|
||||||
|
await self._store.invalidate_login_tokens_by_session_id(
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch any device(s) in the store associated with the session ID.
|
||||||
|
devices = await self._store.get_devices_by_auth_provider_session_id(
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We have no guarantee that all the devices of that session are for the same
|
||||||
|
# `user_id`. Hence, we have to iterate over the list of devices and log them out
|
||||||
|
# one by one.
|
||||||
|
for device in devices:
|
||||||
|
user_id = device["user_id"]
|
||||||
|
device_id = device["device_id"]
|
||||||
|
|
||||||
|
# If the user_id associated with that device/session is not the one we got
|
||||||
|
# out of the `sub` claim, skip that device and show log an error.
|
||||||
|
if expected_user_id is not None and user_id != expected_user_id:
|
||||||
|
logger.error(
|
||||||
|
"Received a logout notification from SSO provider "
|
||||||
|
f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
|
||||||
|
f"a session ID ({auth_provider_session_id!r}) which belongs to "
|
||||||
|
f"{user_id!r}. This may happen when the SSO provider user mapper "
|
||||||
|
"uses something else than the standard attribute as mapping ID. "
|
||||||
|
"For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
|
||||||
|
"in the provider config if that is the case."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
auth_provider_id,
|
||||||
|
auth_provider_session_id,
|
||||||
|
)
|
||||||
|
await self._device_handler.delete_devices(user_id, [device_id])
|
||||||
|
|
||||||
|
|
||||||
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
|
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
|
||||||
"""Extract the session ID from the cookie
|
"""Extract the session ID from the cookie
|
||||||
|
|
|
@ -17,6 +17,9 @@ from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from synapse.rest.synapse.client.oidc.backchannel_logout_resource import (
|
||||||
|
OIDCBackchannelLogoutResource,
|
||||||
|
)
|
||||||
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
|
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -29,6 +32,7 @@ class OIDCResource(Resource):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
self.putChild(b"callback", OIDCCallbackResource(hs))
|
self.putChild(b"callback", OIDCCallbackResource(hs))
|
||||||
|
self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs))
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OIDCResource"]
|
__all__ = ["OIDCResource"]
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from synapse.http.server import DirectServeJsonResource
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCBackchannelLogoutResource(DirectServeJsonResource):
|
||||||
|
isLeaf = 1
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self._oidc_handler = hs.get_oidc_handler()
|
||||||
|
|
||||||
|
async def _async_render_POST(self, request: SynapseRequest) -> None:
|
||||||
|
await self._oidc_handler.handle_backchannel_logout(request)
|
|
@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
self._clock.time_msec(),
|
self._clock.time_msec(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def invalidate_login_tokens_by_session_id(
|
||||||
|
self, auth_provider_id: str, auth_provider_session_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Invalidate login tokens with the given IdP session ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_provider_id: The SSO Identity Provider that the user authenticated with
|
||||||
|
to get this token
|
||||||
|
auth_provider_session_id: The session ID advertised by the SSO Identity
|
||||||
|
Provider
|
||||||
|
"""
|
||||||
|
await self.db_pool.simple_update(
|
||||||
|
table="login_tokens",
|
||||||
|
keyvalues={
|
||||||
|
"auth_provider_id": auth_provider_id,
|
||||||
|
"auth_provider_session_id": auth_provider_session_id,
|
||||||
|
},
|
||||||
|
updatevalues={"used_ts": self._clock.time_msec()},
|
||||||
|
desc="invalidate_login_tokens_by_session_id",
|
||||||
|
)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def is_guest(self, user_id: str) -> bool:
|
async def is_guest(self, user_id: str) -> bool:
|
||||||
res = await self.db_pool.simple_select_one_onecol(
|
res = await self.db_pool.simple_select_one_onecol(
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import re
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -21,7 +22,7 @@ from twisted.web.resource import Resource
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
||||||
from synapse.rest.client import account, auth, devices, login, logout, register
|
from synapse.rest.client import account, auth, devices, login, logout, register
|
||||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||||
|
@ -32,8 +33,8 @@ from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.handlers.test_oidc import HAS_OIDC
|
from tests.handlers.test_oidc import HAS_OIDC
|
||||||
from tests.rest.client.utils import TEST_OIDC_CONFIG
|
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel, make_request
|
||||||
from tests.unittest import override_config, skip_unless
|
from tests.unittest import override_config, skip_unless
|
||||||
|
|
||||||
|
|
||||||
|
@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
{"refresh_token": refresh_token},
|
{"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_access_token_valid(self, access_token: str) -> bool:
|
|
||||||
"""
|
|
||||||
Checks whether an access token is valid, returning whether it is or not.
|
|
||||||
"""
|
|
||||||
code = self.make_request(
|
|
||||||
"GET", "/_matrix/client/v3/account/whoami", access_token=access_token
|
|
||||||
).code
|
|
||||||
|
|
||||||
# Either 200 or 401 is what we get back; anything else is a bug.
|
|
||||||
assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
|
|
||||||
|
|
||||||
return code == HTTPStatus.OK
|
|
||||||
|
|
||||||
def test_login_issue_refresh_token(self) -> None:
|
def test_login_issue_refresh_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
A login response should include a refresh_token only if asked.
|
A login response should include a refresh_token only if asked.
|
||||||
|
@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(59.0)
|
self.reactor.advance(59.0)
|
||||||
|
|
||||||
# Both tokens should still be valid.
|
# Both tokens should still be valid.
|
||||||
self.assertTrue(self.is_access_token_valid(refreshable_access_token))
|
self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
|
||||||
self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
|
self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
# Advance to 61 s (just past 1 minute, the time of expiry)
|
# Advance to 61 s (just past 1 minute, the time of expiry)
|
||||||
self.reactor.advance(2.0)
|
self.reactor.advance(2.0)
|
||||||
|
|
||||||
# Only the non-refreshable token is still valid.
|
# Only the non-refreshable token is still valid.
|
||||||
self.assertFalse(self.is_access_token_valid(refreshable_access_token))
|
self.helper.whoami(
|
||||||
self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
|
refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
# Advance to 599 s (just shy of 10 minutes, the time of expiry)
|
# Advance to 599 s (just shy of 10 minutes, the time of expiry)
|
||||||
self.reactor.advance(599.0 - 61.0)
|
self.reactor.advance(599.0 - 61.0)
|
||||||
|
|
||||||
# It's still the case that only the non-refreshable token is still valid.
|
# It's still the case that only the non-refreshable token is still valid.
|
||||||
self.assertFalse(self.is_access_token_valid(refreshable_access_token))
|
self.helper.whoami(
|
||||||
self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
|
refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
# Advance to 601 s (just past 10 minutes, the time of expiry)
|
# Advance to 601 s (just past 10 minutes, the time of expiry)
|
||||||
self.reactor.advance(2.0)
|
self.reactor.advance(2.0)
|
||||||
|
|
||||||
# Now neither token is valid.
|
# Now neither token is valid.
|
||||||
self.assertFalse(self.is_access_token_valid(refreshable_access_token))
|
self.helper.whoami(
|
||||||
self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
|
refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
self.helper.whoami(
|
||||||
|
nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
|
||||||
|
)
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
|
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
|
||||||
|
@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||||
# and no refresh token
|
# and no refresh token
|
||||||
self.assertEqual(_table_length("access_tokens"), 0)
|
self.assertEqual(_table_length("access_tokens"), 0)
|
||||||
self.assertEqual(_table_length("refresh_tokens"), 0)
|
self.assertEqual(_table_length("refresh_tokens"), 0)
|
||||||
|
|
||||||
|
|
||||||
|
def oidc_config(
|
||||||
|
id: str, with_localpart_template: bool, **kwargs: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Sample OIDC provider config used in backchannel logout tests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: IDP ID for this provider
|
||||||
|
with_localpart_template: Set to `true` to have a default localpart_template in
|
||||||
|
the `user_mapping_provider` config and skip the user mapping session
|
||||||
|
**kwargs: rest of the config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
|
||||||
|
the HS config
|
||||||
|
"""
|
||||||
|
config: Dict[str, Any] = {
|
||||||
|
"idp_id": id,
|
||||||
|
"idp_name": id,
|
||||||
|
"issuer": TEST_OIDC_ISSUER,
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-client-secret",
|
||||||
|
"scopes": ["openid"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if with_localpart_template:
|
||||||
|
config["user_mapping_provider"] = {
|
||||||
|
"config": {"localpart_template": "{{ user.sub }}"}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
config["user_mapping_provider"] = {"config": {}}
|
||||||
|
|
||||||
|
config.update(kwargs)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@skip_unless(HAS_OIDC, "Requires OIDC")
|
||||||
|
class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
account.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> Dict[str, Any]:
|
||||||
|
config = super().default_config()
|
||||||
|
|
||||||
|
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
|
||||||
|
# False, so synapse will see the requested uri as http://..., so using http in
|
||||||
|
# the public_baseurl stops Synapse trying to redirect to https.
|
||||||
|
config["public_baseurl"] = "http://synapse.test"
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
|
resource_dict = super().create_resource_dict()
|
||||||
|
resource_dict.update(build_synapse_client_resource_tree(self.hs))
|
||||||
|
return resource_dict
|
||||||
|
|
||||||
|
def submit_logout_token(self, logout_token: str) -> FakeChannel:
|
||||||
|
return self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_synapse/client/oidc/backchannel_logout",
|
||||||
|
content=f"logout_token={logout_token}",
|
||||||
|
content_is_form=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_providers": [
|
||||||
|
oidc_config(
|
||||||
|
id="oidc",
|
||||||
|
with_localpart_template=True,
|
||||||
|
backchannel_logout_enabled=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_simple_logout(self) -> None:
|
||||||
|
"""
|
||||||
|
Receiving a logout token should logout the user
|
||||||
|
"""
|
||||||
|
fake_oidc_server = self.helper.fake_oidc_server()
|
||||||
|
user = "john"
|
||||||
|
|
||||||
|
login_resp, first_grant = self.helper.login_via_oidc(
|
||||||
|
fake_oidc_server, user, with_sid=True
|
||||||
|
)
|
||||||
|
first_access_token: str = login_resp["access_token"]
|
||||||
|
self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
login_resp, second_grant = self.helper.login_via_oidc(
|
||||||
|
fake_oidc_server, user, with_sid=True
|
||||||
|
)
|
||||||
|
second_access_token: str = login_resp["access_token"]
|
||||||
|
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
self.assertNotEqual(first_grant.sid, second_grant.sid)
|
||||||
|
self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
|
||||||
|
|
||||||
|
# Logging out of the first session
|
||||||
|
logout_token = fake_oidc_server.generate_logout_token(first_grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
|
||||||
|
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Logging out of the second session
|
||||||
|
logout_token = fake_oidc_server.generate_logout_token(second_grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_providers": [
|
||||||
|
oidc_config(
|
||||||
|
id="oidc",
|
||||||
|
with_localpart_template=True,
|
||||||
|
backchannel_logout_enabled=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_logout_during_login(self) -> None:
|
||||||
|
"""
|
||||||
|
It should revoke login tokens when receiving a logout token
|
||||||
|
"""
|
||||||
|
fake_oidc_server = self.helper.fake_oidc_server()
|
||||||
|
user = "john"
|
||||||
|
|
||||||
|
# Get an authentication, and logout before submitting the logout token
|
||||||
|
client_redirect_url = "https://x"
|
||||||
|
userinfo = {"sub": user}
|
||||||
|
channel, grant = self.helper.auth_via_oidc(
|
||||||
|
fake_oidc_server,
|
||||||
|
userinfo,
|
||||||
|
client_redirect_url,
|
||||||
|
with_sid=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# expect a confirmation page
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||||
|
|
||||||
|
# fish the matrix login token out of the body of the confirmation page
|
||||||
|
m = re.search(
|
||||||
|
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
|
||||||
|
channel.text_body,
|
||||||
|
)
|
||||||
|
assert m, channel.text_body
|
||||||
|
login_token = m.group(1)
|
||||||
|
|
||||||
|
# Submit a logout
|
||||||
|
logout_token = fake_oidc_server.generate_logout_token(grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
# Now try to exchange the login token
|
||||||
|
channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"POST",
|
||||||
|
"/login",
|
||||||
|
content={"type": "m.login.token", "token": login_token},
|
||||||
|
)
|
||||||
|
# It should have failed
|
||||||
|
self.assertEqual(channel.code, 403)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_providers": [
|
||||||
|
oidc_config(
|
||||||
|
id="oidc",
|
||||||
|
with_localpart_template=False,
|
||||||
|
backchannel_logout_enabled=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_logout_during_mapping(self) -> None:
|
||||||
|
"""
|
||||||
|
It should stop ongoing user mapping session when receiving a logout token
|
||||||
|
"""
|
||||||
|
fake_oidc_server = self.helper.fake_oidc_server()
|
||||||
|
user = "john"
|
||||||
|
|
||||||
|
# Get an authentication, and logout before submitting the logout token
|
||||||
|
client_redirect_url = "https://x"
|
||||||
|
userinfo = {"sub": user}
|
||||||
|
channel, grant = self.helper.auth_via_oidc(
|
||||||
|
fake_oidc_server,
|
||||||
|
userinfo,
|
||||||
|
client_redirect_url,
|
||||||
|
with_sid=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expect a user mapping page
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||||
|
|
||||||
|
# We should have a user_mapping_session cookie
|
||||||
|
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
|
||||||
|
assert cookie_headers
|
||||||
|
cookies: Dict[str, str] = {}
|
||||||
|
for h in cookie_headers:
|
||||||
|
key, value = h.split(";")[0].split("=", maxsplit=1)
|
||||||
|
cookies[key] = value
|
||||||
|
|
||||||
|
user_mapping_session_id = cookies["username_mapping_session"]
|
||||||
|
|
||||||
|
# Getting that session should not raise
|
||||||
|
session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
|
||||||
|
self.assertIsNotNone(session)
|
||||||
|
|
||||||
|
# Submit a logout
|
||||||
|
logout_token = fake_oidc_server.generate_logout_token(grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
# Now it should raise
|
||||||
|
with self.assertRaises(SynapseError):
|
||||||
|
self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_providers": [
|
||||||
|
oidc_config(
|
||||||
|
id="oidc",
|
||||||
|
with_localpart_template=True,
|
||||||
|
backchannel_logout_enabled=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_disabled(self) -> None:
|
||||||
|
"""
|
||||||
|
Receiving a logout token should do nothing if it is disabled in the config
|
||||||
|
"""
|
||||||
|
fake_oidc_server = self.helper.fake_oidc_server()
|
||||||
|
user = "john"
|
||||||
|
|
||||||
|
login_resp, grant = self.helper.login_via_oidc(
|
||||||
|
fake_oidc_server, user, with_sid=True
|
||||||
|
)
|
||||||
|
access_token: str = login_resp["access_token"]
|
||||||
|
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Logging out shouldn't work
|
||||||
|
logout_token = fake_oidc_server.generate_logout_token(grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
|
||||||
|
# And the token should still be valid
|
||||||
|
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_providers": [
|
||||||
|
oidc_config(
|
||||||
|
id="oidc",
|
||||||
|
with_localpart_template=True,
|
||||||
|
backchannel_logout_enabled=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_no_sid(self) -> None:
|
||||||
|
"""
|
||||||
|
Receiving a logout token without `sid` during the login should do nothing
|
||||||
|
"""
|
||||||
|
fake_oidc_server = self.helper.fake_oidc_server()
|
||||||
|
user = "john"
|
||||||
|
|
||||||
|
login_resp, grant = self.helper.login_via_oidc(
|
||||||
|
fake_oidc_server, user, with_sid=False
|
||||||
|
)
|
||||||
|
access_token: str = login_resp["access_token"]
|
||||||
|
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Logging out shouldn't work
|
||||||
|
logout_token = fake_oidc_server.generate_logout_token(grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
|
||||||
|
# And the token should still be valid
|
||||||
|
self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_providers": [
|
||||||
|
oidc_config(
|
||||||
|
"first",
|
||||||
|
issuer="https://first-issuer.com/",
|
||||||
|
with_localpart_template=True,
|
||||||
|
backchannel_logout_enabled=True,
|
||||||
|
),
|
||||||
|
oidc_config(
|
||||||
|
"second",
|
||||||
|
issuer="https://second-issuer.com/",
|
||||||
|
with_localpart_template=True,
|
||||||
|
backchannel_logout_enabled=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_multiple_providers(self) -> None:
|
||||||
|
"""
|
||||||
|
It should be able to distinguish login tokens from two different IdPs
|
||||||
|
"""
|
||||||
|
first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
|
||||||
|
second_server = self.helper.fake_oidc_server(
|
||||||
|
issuer="https://second-issuer.com/"
|
||||||
|
)
|
||||||
|
user = "john"
|
||||||
|
|
||||||
|
login_resp, first_grant = self.helper.login_via_oidc(
|
||||||
|
first_server, user, with_sid=True, idp_id="oidc-first"
|
||||||
|
)
|
||||||
|
first_access_token: str = login_resp["access_token"]
|
||||||
|
self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
login_resp, second_grant = self.helper.login_via_oidc(
|
||||||
|
second_server, user, with_sid=True, idp_id="oidc-second"
|
||||||
|
)
|
||||||
|
second_access_token: str = login_resp["access_token"]
|
||||||
|
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
# `sid` in the fake providers are generated by a counter, so the first grant of
|
||||||
|
# each provider should give the same SID
|
||||||
|
self.assertEqual(first_grant.sid, second_grant.sid)
|
||||||
|
self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
|
||||||
|
|
||||||
|
# Logging out of the first session
|
||||||
|
logout_token = first_server.generate_logout_token(first_grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
|
||||||
|
self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Logging out of the second session
|
||||||
|
logout_token = second_server.generate_logout_token(second_grant)
|
||||||
|
channel = self.submit_logout_token(logout_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
|
||||||
|
|
|
@ -553,6 +553,34 @@ class RestHelper:
|
||||||
|
|
||||||
return channel.json_body
|
return channel.json_body
|
||||||
|
|
||||||
|
def whoami(
|
||||||
|
self,
|
||||||
|
access_token: str,
|
||||||
|
expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
|
||||||
|
) -> JsonDict:
|
||||||
|
"""Perform a 'whoami' request, which can be a quick way to check for access
|
||||||
|
token validity
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: The user token to use during the request
|
||||||
|
expect_code: The return code to expect from attempting the whoami request
|
||||||
|
"""
|
||||||
|
channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"GET",
|
||||||
|
"account/whoami",
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % (
|
||||||
|
expect_code,
|
||||||
|
channel.code,
|
||||||
|
channel.result["body"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return channel.json_body
|
||||||
|
|
||||||
def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
|
def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
|
||||||
"""Create a ``FakeOidcServer``.
|
"""Create a ``FakeOidcServer``.
|
||||||
|
|
||||||
|
@ -572,6 +600,7 @@ class RestHelper:
|
||||||
fake_server: FakeOidcServer,
|
fake_server: FakeOidcServer,
|
||||||
remote_user_id: str,
|
remote_user_id: str,
|
||||||
with_sid: bool = False,
|
with_sid: bool = False,
|
||||||
|
idp_id: Optional[str] = None,
|
||||||
expected_status: int = 200,
|
expected_status: int = 200,
|
||||||
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
|
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
|
||||||
"""Log in (as a new user) via OIDC
|
"""Log in (as a new user) via OIDC
|
||||||
|
@ -588,7 +617,11 @@ class RestHelper:
|
||||||
client_redirect_url = "https://x"
|
client_redirect_url = "https://x"
|
||||||
userinfo = {"sub": remote_user_id}
|
userinfo = {"sub": remote_user_id}
|
||||||
channel, grant = self.auth_via_oidc(
|
channel, grant = self.auth_via_oidc(
|
||||||
fake_server, userinfo, client_redirect_url, with_sid=with_sid
|
fake_server,
|
||||||
|
userinfo,
|
||||||
|
client_redirect_url,
|
||||||
|
with_sid=with_sid,
|
||||||
|
idp_id=idp_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# expect a confirmation page
|
# expect a confirmation page
|
||||||
|
@ -623,6 +656,7 @@ class RestHelper:
|
||||||
client_redirect_url: Optional[str] = None,
|
client_redirect_url: Optional[str] = None,
|
||||||
ui_auth_session_id: Optional[str] = None,
|
ui_auth_session_id: Optional[str] = None,
|
||||||
with_sid: bool = False,
|
with_sid: bool = False,
|
||||||
|
idp_id: Optional[str] = None,
|
||||||
) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
|
) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
|
||||||
"""Perform an OIDC authentication flow via a mock OIDC provider.
|
"""Perform an OIDC authentication flow via a mock OIDC provider.
|
||||||
|
|
||||||
|
@ -648,6 +682,7 @@ class RestHelper:
|
||||||
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
|
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
|
||||||
of the UI auth.
|
of the UI auth.
|
||||||
with_sid: if True, generates a random `sid` (OIDC session ID)
|
with_sid: if True, generates a random `sid` (OIDC session ID)
|
||||||
|
idp_id: if set, explicitely chooses one specific IDP
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A FakeChannel containing the result of calling the OIDC callback endpoint.
|
A FakeChannel containing the result of calling the OIDC callback endpoint.
|
||||||
|
@ -665,7 +700,9 @@ class RestHelper:
|
||||||
oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
|
oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
|
||||||
else:
|
else:
|
||||||
# otherwise, hit the login redirect endpoint
|
# otherwise, hit the login redirect endpoint
|
||||||
oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
|
oauth_uri = self.initiate_sso_login(
|
||||||
|
client_redirect_url, cookies, idp_id=idp_id
|
||||||
|
)
|
||||||
|
|
||||||
# we now have a URI for the OIDC IdP, but we skip that and go straight
|
# we now have a URI for the OIDC IdP, but we skip that and go straight
|
||||||
# back to synapse's OIDC callback resource. However, we do need the "state"
|
# back to synapse's OIDC callback resource. However, we do need the "state"
|
||||||
|
@ -742,7 +779,10 @@ class RestHelper:
|
||||||
return channel, grant
|
return channel, grant
|
||||||
|
|
||||||
def initiate_sso_login(
|
def initiate_sso_login(
|
||||||
self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
|
self,
|
||||||
|
client_redirect_url: Optional[str],
|
||||||
|
cookies: MutableMapping[str, str],
|
||||||
|
idp_id: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Make a request to the login-via-sso redirect endpoint, and return the target
|
"""Make a request to the login-via-sso redirect endpoint, and return the target
|
||||||
|
|
||||||
|
@ -753,6 +793,7 @@ class RestHelper:
|
||||||
client_redirect_url: the client redirect URL to pass to the login redirect
|
client_redirect_url: the client redirect URL to pass to the login redirect
|
||||||
endpoint
|
endpoint
|
||||||
cookies: any cookies returned will be added to this dict
|
cookies: any cookies returned will be added to this dict
|
||||||
|
idp_id: if set, explicitely chooses one specific IDP
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
the URI that the client gets redirected to (ie, the SSO server)
|
the URI that the client gets redirected to (ie, the SSO server)
|
||||||
|
@ -761,6 +802,12 @@ class RestHelper:
|
||||||
if client_redirect_url:
|
if client_redirect_url:
|
||||||
params["redirectUrl"] = client_redirect_url
|
params["redirectUrl"] = client_redirect_url
|
||||||
|
|
||||||
|
uri = "/_matrix/client/r0/login/sso/redirect"
|
||||||
|
if idp_id is not None:
|
||||||
|
uri = f"{uri}/{idp_id}"
|
||||||
|
|
||||||
|
uri = f"{uri}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
# hit the redirect url (which should redirect back to the redirect url. This
|
# hit the redirect url (which should redirect back to the redirect url. This
|
||||||
# is the easiest way of figuring out what the Host header ought to be set to
|
# is the easiest way of figuring out what the Host header ought to be set to
|
||||||
# to keep Synapse happy.
|
# to keep Synapse happy.
|
||||||
|
@ -768,7 +815,7 @@ class RestHelper:
|
||||||
self.hs.get_reactor(),
|
self.hs.get_reactor(),
|
||||||
self.site,
|
self.site,
|
||||||
"GET",
|
"GET",
|
||||||
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
|
uri,
|
||||||
)
|
)
|
||||||
assert channel.code == 302
|
assert channel.code == 302
|
||||||
|
|
||||||
|
|
|
@ -362,6 +362,12 @@ def make_request(
|
||||||
# Twisted expects to be at the end of the content when parsing the request.
|
# Twisted expects to be at the end of the content when parsing the request.
|
||||||
req.content.seek(0, SEEK_END)
|
req.content.seek(0, SEEK_END)
|
||||||
|
|
||||||
|
# Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
|
||||||
|
# bodies if the Content-Length header is missing
|
||||||
|
req.requestHeaders.addRawHeader(
|
||||||
|
b"Content-Length", str(len(content)).encode("ascii")
|
||||||
|
)
|
||||||
|
|
||||||
if access_token:
|
if access_token:
|
||||||
req.requestHeaders.addRawHeader(
|
req.requestHeaders.addRawHeader(
|
||||||
b"Authorization", b"Bearer " + access_token.encode("ascii")
|
b"Authorization", b"Bearer " + access_token.encode("ascii")
|
||||||
|
|
|
@ -51,6 +51,8 @@ class FakeOidcServer:
|
||||||
get_userinfo_handler: Mock
|
get_userinfo_handler: Mock
|
||||||
post_token_handler: Mock
|
post_token_handler: Mock
|
||||||
|
|
||||||
|
sid_counter: int = 0
|
||||||
|
|
||||||
def __init__(self, clock: Clock, issuer: str):
|
def __init__(self, clock: Clock, issuer: str):
|
||||||
from authlib.jose import ECKey, KeySet
|
from authlib.jose import ECKey, KeySet
|
||||||
|
|
||||||
|
@ -146,7 +148,7 @@ class FakeOidcServer:
|
||||||
return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
|
return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
|
||||||
|
|
||||||
def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
|
def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
|
||||||
now = self._clock.time()
|
now = int(self._clock.time())
|
||||||
id_token = {
|
id_token = {
|
||||||
**grant.userinfo,
|
**grant.userinfo,
|
||||||
"iss": self.issuer,
|
"iss": self.issuer,
|
||||||
|
@ -166,6 +168,26 @@ class FakeOidcServer:
|
||||||
|
|
||||||
return self._sign(id_token)
|
return self._sign(id_token)
|
||||||
|
|
||||||
|
def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
|
||||||
|
now = int(self._clock.time())
|
||||||
|
logout_token = {
|
||||||
|
"iss": self.issuer,
|
||||||
|
"aud": grant.client_id,
|
||||||
|
"iat": now,
|
||||||
|
"jti": random_string(10),
|
||||||
|
"events": {
|
||||||
|
"http://schemas.openid.net/event/backchannel-logout": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if grant.sid is not None:
|
||||||
|
logout_token["sid"] = grant.sid
|
||||||
|
|
||||||
|
if "sub" in grant.userinfo:
|
||||||
|
logout_token["sub"] = grant.userinfo["sub"]
|
||||||
|
|
||||||
|
return self._sign(logout_token)
|
||||||
|
|
||||||
def id_token_override(self, overrides: dict):
|
def id_token_override(self, overrides: dict):
|
||||||
"""Temporarily patch the ID token generated by the token endpoint."""
|
"""Temporarily patch the ID token generated by the token endpoint."""
|
||||||
return patch.object(self, "_id_token_overrides", overrides)
|
return patch.object(self, "_id_token_overrides", overrides)
|
||||||
|
@ -183,7 +205,8 @@ class FakeOidcServer:
|
||||||
code = random_string(10)
|
code = random_string(10)
|
||||||
sid = None
|
sid = None
|
||||||
if with_sid:
|
if with_sid:
|
||||||
sid = random_string(10)
|
sid = str(self.sid_counter)
|
||||||
|
self.sid_counter += 1
|
||||||
|
|
||||||
grant = FakeAuthorizationGrant(
|
grant = FakeAuthorizationGrant(
|
||||||
userinfo=userinfo,
|
userinfo=userinfo,
|
||||||
|
|
Loading…
Reference in a new issue