forked from MirrorHub/synapse
Add config for customizing the claim used for JWT logins. (#11361)
Allows specifying a different claim (from the default "sub") to use when calculating the localpart of the Matrix ID used during the JWT login.
This commit is contained in:
parent
3d893b8cf2
commit
1035663833
6 changed files with 57 additions and 35 deletions
1
changelog.d/11361.feature
Normal file
1
changelog.d/11361.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update the JWT login type to support custom a `sub` claim.
|
|
@ -22,8 +22,9 @@ will be removed in a future version of Synapse.
|
||||||
|
|
||||||
The `token` field should include the JSON web token with the following claims:
|
The `token` field should include the JSON web token with the following claims:
|
||||||
|
|
||||||
* The `sub` (subject) claim is required and should encode the local part of the
|
* A claim that encodes the local part of the user ID is required. By default,
|
||||||
user ID.
|
the `sub` (subject) claim is used, or a custom claim can be set in the
|
||||||
|
configuration file.
|
||||||
* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
|
* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
|
||||||
claims are optional, but validated if present.
|
claims are optional, but validated if present.
|
||||||
* The issuer (`iss`) claim is optional, but required and validated if configured.
|
* The issuer (`iss`) claim is optional, but required and validated if configured.
|
||||||
|
|
|
@ -2039,6 +2039,12 @@ sso:
|
||||||
#
|
#
|
||||||
#algorithm: "provided-by-your-issuer"
|
#algorithm: "provided-by-your-issuer"
|
||||||
|
|
||||||
|
# Name of the claim containing a unique identifier for the user.
|
||||||
|
#
|
||||||
|
# Optional, defaults to `sub`.
|
||||||
|
#
|
||||||
|
#subject_claim: "sub"
|
||||||
|
|
||||||
# The issuer to validate the "iss" claim against.
|
# The issuer to validate the "iss" claim against.
|
||||||
#
|
#
|
||||||
# Optional, if provided the "iss" claim will be required and
|
# Optional, if provided the "iss" claim will be required and
|
||||||
|
|
|
@ -31,6 +31,8 @@ class JWTConfig(Config):
|
||||||
self.jwt_secret = jwt_config["secret"]
|
self.jwt_secret = jwt_config["secret"]
|
||||||
self.jwt_algorithm = jwt_config["algorithm"]
|
self.jwt_algorithm = jwt_config["algorithm"]
|
||||||
|
|
||||||
|
self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
|
||||||
|
|
||||||
# The issuer and audiences are optional, if provided, it is asserted
|
# The issuer and audiences are optional, if provided, it is asserted
|
||||||
# that the claims exist on the JWT.
|
# that the claims exist on the JWT.
|
||||||
self.jwt_issuer = jwt_config.get("issuer")
|
self.jwt_issuer = jwt_config.get("issuer")
|
||||||
|
@ -46,6 +48,7 @@ class JWTConfig(Config):
|
||||||
self.jwt_enabled = False
|
self.jwt_enabled = False
|
||||||
self.jwt_secret = None
|
self.jwt_secret = None
|
||||||
self.jwt_algorithm = None
|
self.jwt_algorithm = None
|
||||||
|
self.jwt_subject_claim = None
|
||||||
self.jwt_issuer = None
|
self.jwt_issuer = None
|
||||||
self.jwt_audiences = None
|
self.jwt_audiences = None
|
||||||
|
|
||||||
|
@ -88,6 +91,12 @@ class JWTConfig(Config):
|
||||||
#
|
#
|
||||||
#algorithm: "provided-by-your-issuer"
|
#algorithm: "provided-by-your-issuer"
|
||||||
|
|
||||||
|
# Name of the claim containing a unique identifier for the user.
|
||||||
|
#
|
||||||
|
# Optional, defaults to `sub`.
|
||||||
|
#
|
||||||
|
#subject_claim: "sub"
|
||||||
|
|
||||||
# The issuer to validate the "iss" claim against.
|
# The issuer to validate the "iss" claim against.
|
||||||
#
|
#
|
||||||
# Optional, if provided the "iss" claim will be required and
|
# Optional, if provided the "iss" claim will be required and
|
||||||
|
|
|
@ -72,6 +72,7 @@ class LoginRestServlet(RestServlet):
|
||||||
# JWT configuration variables.
|
# JWT configuration variables.
|
||||||
self.jwt_enabled = hs.config.jwt.jwt_enabled
|
self.jwt_enabled = hs.config.jwt.jwt_enabled
|
||||||
self.jwt_secret = hs.config.jwt.jwt_secret
|
self.jwt_secret = hs.config.jwt.jwt_secret
|
||||||
|
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
|
||||||
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
|
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
|
||||||
self.jwt_issuer = hs.config.jwt.jwt_issuer
|
self.jwt_issuer = hs.config.jwt.jwt_issuer
|
||||||
self.jwt_audiences = hs.config.jwt.jwt_audiences
|
self.jwt_audiences = hs.config.jwt.jwt_audiences
|
||||||
|
@ -413,7 +414,7 @@ class LoginRestServlet(RestServlet):
|
||||||
errcode=Codes.FORBIDDEN,
|
errcode=Codes.FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = payload.get("sub", None)
|
user = payload.get(self.jwt_subject_claim, None)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
|
|
@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
jwt_secret = "secret"
|
jwt_secret = "secret"
|
||||||
jwt_algorithm = "HS256"
|
jwt_algorithm = "HS256"
|
||||||
|
base_config = {
|
||||||
|
"enabled": True,
|
||||||
|
"secret": jwt_secret,
|
||||||
|
"algorithm": jwt_algorithm,
|
||||||
|
}
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def default_config(self):
|
||||||
self.hs = self.setup_test_homeserver()
|
config = super().default_config()
|
||||||
self.hs.config.jwt.jwt_enabled = True
|
|
||||||
self.hs.config.jwt.jwt_secret = self.jwt_secret
|
# If jwt_config has been defined (eg via @override_config), don't replace it.
|
||||||
self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
|
if config.get("jwt_config") is None:
|
||||||
return self.hs
|
config["jwt_config"] = self.base_config
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
|
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
|
||||||
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
|
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
|
||||||
|
@ -879,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||||
|
|
||||||
@override_config(
|
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
|
||||||
{
|
|
||||||
"jwt_config": {
|
|
||||||
"jwt_enabled": True,
|
|
||||||
"secret": jwt_secret,
|
|
||||||
"algorithm": jwt_algorithm,
|
|
||||||
"issuer": "test-issuer",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
def test_login_iss(self):
|
def test_login_iss(self):
|
||||||
"""Test validating the issuer claim."""
|
"""Test validating the issuer claim."""
|
||||||
# A valid issuer.
|
# A valid issuer.
|
||||||
|
@ -919,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
@override_config(
|
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
|
||||||
{
|
|
||||||
"jwt_config": {
|
|
||||||
"jwt_enabled": True,
|
|
||||||
"secret": jwt_secret,
|
|
||||||
"algorithm": jwt_algorithm,
|
|
||||||
"audiences": ["test-audience"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
def test_login_aud(self):
|
def test_login_aud(self):
|
||||||
"""Test validating the audience claim."""
|
"""Test validating the audience claim."""
|
||||||
# A valid audience.
|
# A valid audience.
|
||||||
|
@ -962,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
channel.json_body["error"], "JWT validation failed: Invalid audience"
|
channel.json_body["error"], "JWT validation failed: Invalid audience"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_login_default_sub(self):
|
||||||
|
"""Test reading user ID from the default subject claim."""
|
||||||
|
channel = self.jwt_login({"sub": "kermit"})
|
||||||
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
|
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
|
||||||
|
def test_login_custom_sub(self):
|
||||||
|
"""Test reading user ID from a custom subject claim."""
|
||||||
|
channel = self.jwt_login({"username": "frog"})
|
||||||
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||||
|
|
||||||
def test_login_no_token(self):
|
def test_login_no_token(self):
|
||||||
params = {"type": "org.matrix.login.jwt"}
|
params = {"type": "org.matrix.login.jwt"}
|
||||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
|
@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def default_config(self):
|
||||||
self.hs = self.setup_test_homeserver()
|
config = super().default_config()
|
||||||
self.hs.config.jwt.jwt_enabled = True
|
config["jwt_config"] = {
|
||||||
self.hs.config.jwt.jwt_secret = self.jwt_pubkey
|
"enabled": True,
|
||||||
self.hs.config.jwt.jwt_algorithm = "RS256"
|
"secret": self.jwt_pubkey,
|
||||||
return self.hs
|
"algorithm": "RS256",
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
|
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
|
||||||
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
|
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
|
||||||
|
|
Loading…
Reference in a new issue