0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-03 11:18:56 +02:00

JWT OIDC secrets for Sign in with Apple (#9549)

Apple had to be special. They want a client secret which is generated from an EC key.

Fixes #9220. Also fixes #9212 while I'm here.
This commit is contained in:
Richard van der Hoff 2021-03-09 15:03:37 +00:00 committed by GitHub
parent 9cd18cc588
commit eaada74075
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 444 additions and 47 deletions

View file

@ -20,9 +20,10 @@ recursive-include scripts *
recursive-include scripts-dev * recursive-include scripts-dev *
recursive-include synapse *.pyi recursive-include synapse *.pyi
recursive-include tests *.py recursive-include tests *.py
include tests/http/ca.crt recursive-include tests *.pem
include tests/http/ca.key recursive-include tests *.p8
include tests/http/server.key recursive-include tests *.crt
recursive-include tests *.key
recursive-include synapse/res * recursive-include synapse/res *
recursive-include synapse/static *.css recursive-include synapse/static *.css

1
changelog.d/9549.feature Normal file
View file

@ -0,0 +1 @@
Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets.

View file

@ -386,7 +386,7 @@ oidc_providers:
config: config:
subject_claim: "id" subject_claim: "id"
localpart_template: "{{ user.login }}" localpart_template: "{{ user.login }}"
display_name_template: "{{ user.full_name }}" display_name_template: "{{ user.full_name }}"
``` ```
### XWiki ### XWiki
@ -401,8 +401,7 @@ oidc_providers:
idp_name: "XWiki" idp_name: "XWiki"
issuer: "https://myxwikihost/xwiki/oidc/" issuer: "https://myxwikihost/xwiki/oidc/"
client_id: "your-client-id" # TO BE FILLED client_id: "your-client-id" # TO BE FILLED
# Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed client_auth_method: none
client_secret: "dontcare"
scopes: ["openid", "profile"] scopes: ["openid", "profile"]
user_profile_method: "userinfo_endpoint" user_profile_method: "userinfo_endpoint"
user_mapping_provider: user_mapping_provider:
@ -410,3 +409,40 @@ oidc_providers:
localpart_template: "{{ user.preferred_username }}" localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}" display_name_template: "{{ user.name }}"
``` ```
## Apple
Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
You will need to create a new "Services ID" for SiWA, and create and download a
private key with "SiWA" enabled.
As well as the private key file, you will need:
* Client ID: the "identifier" you gave the "Services ID"
* Team ID: a 10-character ID associated with your developer account.
* Key ID: the 10-character identifier for the key.
https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
documentation on setting up SiWA.
The synapse config will look like this:
```yaml
- idp_id: apple
idp_name: Apple
issuer: "https://appleid.apple.com"
client_id: "your-client-id" # Set to the "identifier" for your "ServicesID"
client_auth_method: "client_secret_post"
client_secret_jwt_key:
key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file
jwt_header:
alg: ES256
kid: "KEYIDCODE" # Set to the 10-char Key ID
jwt_payload:
iss: TEAMIDCODE # Set to the 10-char Team ID
scopes: ["name", "email", "openid"]
authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
user_mapping_provider:
config:
email_template: "{{ user.email }}"
```

View file

@ -1779,7 +1779,26 @@ saml2_config:
# #
# client_id: Required. oauth2 client id to use. # client_id: Required. oauth2 client id to use.
# #
# client_secret: Required. oauth2 client secret to use. # client_secret: oauth2 client secret to use. May be omitted if
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
#
# client_secret_jwt_key: Alternative to client_secret: details of a key used
# to create a JSON Web Token to be used as an OAuth2 client secret. If
# given, must be a dictionary with the following properties:
#
# key: a pem-encoded signing key. Must be a suitable key for the
# algorithm specified. Required unless 'key_file' is given.
#
# key_file: the path to file containing a pem-encoded signing key file.
# Required unless 'key' is given.
#
# jwt_header: a dictionary giving properties to include in the JWT
# header. Must include the key 'alg', giving the algorithm used to
# sign the JWT, such as "ES256", using the JWA identifiers in
# RFC7518.
#
# jwt_payload: an optional dictionary giving properties to include in
# the JWT payload. Normally this should include an 'iss' key.
# #
# client_auth_method: auth method to use when exchanging the token. Valid # client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and # values are 'client_secret_basic' (default), 'client_secret_post' and

View file

@ -212,9 +212,8 @@ class Config:
@classmethod @classmethod
def read_file(cls, file_path, config_name): def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name) """Deprecated: call read_file directly"""
with open(file_path) as file_stream: return read_file(file_path, (config_name,))
return file_stream.read()
def read_template(self, filename: str) -> jinja2.Template: def read_template(self, filename: str) -> jinja2.Template:
"""Load a template file from disk. """Load a template file from disk.
@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
return self._get_instance(key) return self._get_instance(key)
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] def read_file(file_path: Any, config_path: Iterable[str]) -> str:
"""Check the given file exists, and read it into a string
If it does not, emit an error indicating the problem
Args:
file_path: the file to be read
config_path: where in the configuration file_path came from, so that a useful
error can be emitted if it does not exist.
Returns:
content of the file.
Raises:
ConfigError if there is a problem reading the file.
"""
if not isinstance(file_path, str):
raise ConfigError("%r is not a string", config_path)
try:
os.stat(file_path)
with open(file_path) as file_stream:
return file_stream.read()
except OSError as e:
raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
__all__ = [
"Config",
"RootConfig",
"ShardedWorkerHandlingConfig",
"RoutableShardedWorkerHandlingConfig",
"read_file",
]

View file

@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ... def get_instance(self, key: str) -> str: ...
def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...

View file

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
from collections import Counter from collections import Counter
from typing import Iterable, Optional, Tuple, Type from typing import Iterable, Mapping, Optional, Tuple, Type
import attr import attr
@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri from synapse.util.stringutils import parse_and_validate_mxc_uri
from ._base import Config, ConfigError from ._base import Config, ConfigError, read_file
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
@ -97,7 +97,26 @@ class OIDCConfig(Config):
# #
# client_id: Required. oauth2 client id to use. # client_id: Required. oauth2 client id to use.
# #
# client_secret: Required. oauth2 client secret to use. # client_secret: oauth2 client secret to use. May be omitted if
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
#
# client_secret_jwt_key: Alternative to client_secret: details of a key used
# to create a JSON Web Token to be used as an OAuth2 client secret. If
# given, must be a dictionary with the following properties:
#
# key: a pem-encoded signing key. Must be a suitable key for the
# algorithm specified. Required unless 'key_file' is given.
#
# key_file: the path to file containing a pem-encoded signing key file.
# Required unless 'key' is given.
#
# jwt_header: a dictionary giving properties to include in the JWT
# header. Must include the key 'alg', giving the algorithm used to
# sign the JWT, such as "ES256", using the JWA identifiers in
# RFC7518.
#
# jwt_payload: an optional dictionary giving properties to include in
# the JWT payload. Normally this should include an 'iss' key.
# #
# client_auth_method: auth method to use when exchanging the token. Valid # client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and # values are 'client_secret_basic' (default), 'client_secret_post' and
@ -240,7 +259,7 @@ class OIDCConfig(Config):
# jsonschema definition of the configuration settings for an oidc identity provider # jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = { OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object", "type": "object",
"required": ["issuer", "client_id", "client_secret"], "required": ["issuer", "client_id"],
"properties": { "properties": {
"idp_id": { "idp_id": {
"type": "string", "type": "string",
@ -262,6 +281,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"issuer": {"type": "string"}, "issuer": {"type": "string"},
"client_id": {"type": "string"}, "client_id": {"type": "string"},
"client_secret": {"type": "string"}, "client_secret": {"type": "string"},
"client_secret_jwt_key": {
"type": "object",
"required": ["jwt_header"],
"oneOf": [
{"required": ["key"]},
{"required": ["key_file"]},
],
"properties": {
"key": {"type": "string"},
"key_file": {"type": "string"},
"jwt_header": {
"type": "object",
"required": ["alg"],
"properties": {
"alg": {"type": "string"},
},
"additionalProperties": {"type": "string"},
},
"jwt_payload": {
"type": "object",
"additionalProperties": {"type": "string"},
},
},
},
"client_auth_method": { "client_auth_method": {
"type": "string", "type": "string",
# the following list is the same as the keys of # the following list is the same as the keys of
@ -404,6 +447,20 @@ def _parse_oidc_config_dict(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",) "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e ) from e
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
if client_secret_jwt_key_config is not None:
keyfile = client_secret_jwt_key_config.get("key_file")
if keyfile:
key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
else:
key = client_secret_jwt_key_config["key"]
client_secret_jwt_key = OidcProviderClientSecretJwtKey(
key=key,
jwt_header=client_secret_jwt_key_config["jwt_header"],
jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
)
return OidcProviderConfig( return OidcProviderConfig(
idp_id=idp_id, idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"), idp_name=oidc_config.get("idp_name", "OIDC"),
@ -412,7 +469,8 @@ def _parse_oidc_config_dict(
discover=oidc_config.get("discover", True), discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"], issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"], client_id=oidc_config["client_id"],
client_secret=oidc_config["client_secret"], client_secret=oidc_config.get("client_secret"),
client_secret_jwt_key=client_secret_jwt_key,
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]), scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"), authorization_endpoint=oidc_config.get("authorization_endpoint"),
@ -427,6 +485,18 @@ def _parse_oidc_config_dict(
) )
@attr.s(slots=True, frozen=True)
class OidcProviderClientSecretJwtKey:
# a pem-encoded signing key
key = attr.ib(type=str)
# properties to include in the JWT header
jwt_header = attr.ib(type=Mapping[str, str])
# properties to include in the JWT payload.
jwt_payload = attr.ib(type=Mapping[str, str])
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class OidcProviderConfig: class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids' # a unique identifier for this identity provider. Used in the 'user_external_ids'
@ -452,8 +522,13 @@ class OidcProviderConfig:
# oauth2 client id to use # oauth2 client id to use
client_id = attr.ib(type=str) client_id = attr.ib(type=str)
# oauth2 client secret to use # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
client_secret = attr.ib(type=str) # a secret.
client_secret = attr.ib(type=Optional[str])
# key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set.
client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
# auth method to use when exchanging the token. # auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and # Valid values are 'client_secret_basic', 'client_secret_post' and

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech # Copyright 2020 Quentin Gliech
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,13 +15,13 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import logging import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode from urllib.parse import urlencode
import attr import attr
import pymacaroons import pymacaroons
from authlib.common.security import generate_token from authlib.common.security import generate_token
from authlib.jose import JsonWebToken from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@ -35,12 +36,15 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody from twisted.web.client import readBody
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.oidc_config import OidcProviderConfig from synapse.config.oidc_config import (
OidcProviderClientSecretJwtKey,
OidcProviderConfig,
)
from synapse.handlers.sso import MappingException, UserAttributes from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
@ -276,9 +280,21 @@ class OidcProvider:
self._scopes = provider.scopes self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method self._user_profile_method = provider.user_profile_method
client_secret = None # type: Union[None, str, JwtClientSecret]
if provider.client_secret:
client_secret = provider.client_secret
elif provider.client_secret_jwt_key:
client_secret = JwtClientSecret(
provider.client_secret_jwt_key,
provider.client_id,
provider.issuer,
hs.get_clock(),
)
self._client_auth = ClientAuth( self._client_auth = ClientAuth(
provider.client_id, provider.client_id,
provider.client_secret, client_secret,
provider.client_auth_method, provider.client_auth_method,
) # type: ClientAuth ) # type: ClientAuth
self._client_auth_method = provider.client_auth_method self._client_auth_method = provider.client_auth_method
@ -977,6 +993,81 @@ class OidcProvider:
return str(remote_user_id) return str(remote_user_id)
# number of seconds a newly-generated client secret should be valid for
CLIENT_SECRET_VALIDITY_SECONDS = 3600
# minimum remaining validity on a client secret before we should generate a new one
CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
class JwtClientSecret:
"""A class which generates a new client secret on demand, based on a JWK
This implementation is designed to comply with the requirements for Apple Sign in:
https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
but it's worth noting that we still put the generated secret in the "client_secret"
field (or rather, whereever client_auth_method puts it) rather than in a
client_assertion field in the body as that RFC seems to require.
"""
def __init__(
self,
key: OidcProviderClientSecretJwtKey,
oauth_client_id: str,
oauth_issuer: str,
clock: Clock,
):
self._key = key
self._oauth_client_id = oauth_client_id
self._oauth_issuer = oauth_issuer
self._clock = clock
self._cached_secret = b""
self._cached_secret_replacement_time = 0
def __str__(self):
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
# here.
return self._get_secret().decode("ascii")
def __bytes__(self):
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
# encode_client_secret_post, which ends up here.
return self._get_secret()
def _get_secret(self) -> bytes:
now = self._clock.time()
# if we have enough validity on our existing secret, use it
if now < self._cached_secret_replacement_time:
return self._cached_secret
issued_at = int(now)
expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
# we copy the configured header because jwt.encode modifies it.
header = dict(self._key.jwt_header)
# see https://tools.ietf.org/html/rfc7523#section-3
payload = {
"sub": self._oauth_client_id,
"aud": self._oauth_issuer,
"iat": issued_at,
"exp": expires_at,
**self._key.jwt_payload,
}
logger.info(
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
)
self._cached_secret = jwt.encode(header, payload, self._key.key)
self._cached_secret_replacement_time = (
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
)
return self._cached_secret
class OidcSessionTokenGenerator: class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies.""" """Methods for generating and checking OIDC Session cookies."""

View file

@ -0,0 +1,5 @@
-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
-----END PRIVATE KEY-----

View file

@ -0,0 +1,4 @@
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
-----END PUBLIC KEY-----

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import os
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from mock import ANY, Mock, patch from mock import ANY, Mock, patch
@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
JWKS_URI = ISSUER + ".well-known/jwks.json" JWKS_URI = ISSUER + ".well-known/jwks.json"
# config for common cases # config for common cases
COMMON_CONFIG = { DEFAULT_CONFIG = {
"enabled": True,
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"issuer": ISSUER,
"scopes": SCOPES,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# extends the default config with explicit OAuth2 endpoints instead of using discovery
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False, "discover": False,
"authorization_endpoint": AUTHORIZATION_ENDPOINT, "authorization_endpoint": AUTHORIZATION_ENDPOINT,
"token_endpoint": TOKEN_ENDPOINT, "token_endpoint": TOKEN_ENDPOINT,
@ -107,6 +119,32 @@ async def get_json(url):
return {"keys": []} return {"keys": []}
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
# this key was generated with:
# openssl ecparam -name prime256v1 -genkey -noout |
# openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
#
# we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
# that's what Apple use, and we want to be sure that we work with Apple's keys.
#
# (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
# keys using ASN.1. Both are then typically formatted using PEM, which says: use the
# base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
# really need to care about any of that.)
return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
def _public_key_file_path() -> str:
"""path to a file containing the public half of a test key"""
# this was generated with:
# openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
#
# See above about where oidc_test_key.p8 came from
return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
class OidcHandlerTestCase(HomeserverTestCase): class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC: if not HAS_OIDC:
skip = "requires OIDC" skip = "requires OIDC"
@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def default_config(self): def default_config(self):
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = BASE_URL
oidc_config = {
"enabled": True,
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"issuer": ISSUER,
"scopes": SCOPES,
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
# override_config works as expected.
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.reset_mock() self.render_error.reset_mock()
return args return args
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self): def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly.""" """Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self): def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document.""" """The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid # This would throw if some metadata were invalid
@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.provider.load_metadata()) self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG}) @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self): def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document.""" """When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata()) self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG}) @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self): def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used.""" """JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks()) jwks = self.get_success(self.provider.load_jwks())
@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []}) self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self): def test_validate_config(self):
"""Provider metadatas are extensively validated.""" """Provider metadatas are extensively validated."""
h = self.provider h = self.provider
@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Shouldn't raise with a valid userinfo, even without jwks # Shouldn't raise with a valid userinfo, even without jwks
force_load_metadata() force_load_metadata()
@override_config({"oidc_config": {"skip_verification": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self): def test_skip_verification(self):
"""Provider metadata validation can be disabled by config.""" """Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}): with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw # This should not throw
get_awaitable_result(self.provider.load_metadata()) get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self): def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie.""" """The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"]) req = Mock(spec=["cookies"])
@ -368,6 +395,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(params["nonce"], [nonce]) self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect") self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self): def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed.""" """Errors from the provider returned in the callback are displayed."""
request = Mock(args={}) request = Mock(args={})
@ -379,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description") self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self): def test_callback(self):
"""Code callback works and display errors if something went wrong. """Code callback works and display errors if something went wrong.
@ -480,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request") self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self): def test_callback_session(self):
"""The callback verifies the session presence and validity""" """The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"]) request = Mock(spec=["args", "getCookie", "cookies"])
@ -522,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request") self.assertRenderedError("invalid_request")
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) @override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
def test_exchange_code(self): def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios.""" """Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"} token = {"type": "bearer"}
@ -607,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config( @override_config(
{ {
"oidc_config": { "oidc_config": {
"enabled": True,
"client_id": CLIENT_ID,
"issuer": ISSUER,
"client_auth_method": "client_secret_post",
"client_secret_jwt_key": {
"key_file": _key_file_path(),
"jwt_header": {"alg": "ES256", "kid": "ABC789"},
"jwt_payload": {"iss": "DEFGHI"},
},
}
}
)
def test_exchange_code_jwt_key(self):
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
)
)
code = "code"
# advance the clock a bit before we start, so we aren't working with zero
# timestamps.
self.reactor.advance(1000)
start_time = self.reactor.seconds()
ret = self.get_success(self.provider._exchange_code(code))
self.assertEqual(ret, token)
# the request should have hit the token endpoint
kwargs = self.http_client.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
# the client secret provided to the should be a jwt which can be checked with
# the public key
args = parse_qs(kwargs["data"].decode("utf-8"))
secret = args["client_secret"][0]
with open(_public_key_file_path()) as f:
key = f.read()
claims = jwt.decode(secret, key)
self.assertEqual(claims.header["kid"], "ABC789")
self.assertEqual(claims["aud"], ISSUER)
self.assertEqual(claims["iss"], "DEFGHI")
self.assertEqual(claims["sub"], CLIENT_ID)
self.assertEqual(claims["iat"], start_time)
self.assertGreater(claims["exp"], start_time)
# check the rest of the POSTed data
self.assertEqual(args["grant_type"], ["authorization_code"])
self.assertEqual(args["code"], [code])
self.assertEqual(args["client_id"], [CLIENT_ID])
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
@override_config(
{
"oidc_config": {
"enabled": True,
"client_id": CLIENT_ID,
"issuer": ISSUER,
"client_auth_method": "none",
}
}
)
def test_exchange_code_no_auth(self):
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
return_value=FakeResponse(
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
)
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
self.assertEqual(ret, token)
# the request should have hit the token endpoint
kwargs = self.http_client.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
# check the POSTed data
args = parse_qs(kwargs["data"].decode("utf-8"))
self.assertEqual(args["grant_type"], ["authorization_code"])
self.assertEqual(args["code"], [code])
self.assertEqual(args["client_id"], [CLIENT_ID])
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
@override_config(
{
"oidc_config": {
**DEFAULT_CONFIG,
"user_mapping_provider": { "user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra" "module": __name__ + ".TestMappingProviderExtra"
} },
} }
} }
) )
@ -652,6 +780,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
new_user=True, new_user=True,
) )
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self): def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
@ -692,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"Mapping provider does not support de-duplicating Matrix IDs", "Mapping provider does not support de-duplicating Matrix IDs",
) )
@override_config({"oidc_config": {"allow_existing_users": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self): def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True.""" """Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore() store = self.hs.get_datastore()
@ -772,6 +901,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
) )
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self): def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success( self.get_success(
@ -782,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config( @override_config(
{ {
"oidc_config": { "oidc_config": {
**DEFAULT_CONFIG,
"user_mapping_provider": { "user_mapping_provider": {
"module": __name__ + ".TestMappingProviderFailures" "module": __name__ + ".TestMappingProviderFailures"
} },
} }
} }
) )
@ -829,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"mapping_error", "Unable to generate a Matrix ID from the SSO response" "mapping_error", "Unable to generate a Matrix ID from the SSO response"
) )
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_empty_localpart(self): def test_empty_localpart(self):
"""Attempts to map onto an empty localpart should be rejected.""" """Attempts to map onto an empty localpart should be rejected."""
userinfo = { userinfo = {
@ -841,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config( @override_config(
{ {
"oidc_config": { "oidc_config": {
**DEFAULT_CONFIG,
"user_mapping_provider": { "user_mapping_provider": {
"config": {"localpart_template": "{{ user.username }}"} "config": {"localpart_template": "{{ user.username }}"}
} },
} }
} }
) )