0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-01 10:18:54 +02:00

Add the option to validate the iss and aud claims for JWT logins. (#7827)

This commit is contained in:
Patrick Cloke 2020-07-14 07:16:43 -04:00 committed by GitHub
parent 4db1509516
commit 77d2c05410
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 182 additions and 15 deletions

1
changelog.d/7827.feature Normal file
View file

@ -0,0 +1 @@
Add the option to validate the `iss` and `aud` claims for JWT logins.

View file

@ -20,8 +20,17 @@ follows:
Note that the login type of `m.login.jwt` is supported, but is deprecated. This
will be removed in a future version of Synapse.
The `jwt` should encode the local part of the user ID as the standard `sub`
claim. In the case that the token is not valid, the homeserver must respond with
The `token` field should include the JSON web token with the following claims:
* The `sub` (subject) claim is required and should encode the local part of the
user ID.
* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
claims are optional, but validated if present.
* The issuer (`iss`) claim is optional, but required and validated if configured.
* The audience (`aud`) claim is optional, but required and validated if configured.
Providing the audience claim when not configured will cause validation to fail.
In the case that the token is not valid, the homeserver must respond with
`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
(Note that this differs from the token based logins which return a
@ -55,7 +64,8 @@ sample settings.
Although JSON Web Tokens are typically generated from an external server, the
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
1. Configure Synapse with JWT logins:
1. Configure Synapse with JWT logins, note that this example uses a pre-shared
secret and an algorithm of HS256:
```yaml
jwt_config:

View file

@ -1812,6 +1812,9 @@ sso:
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
# Additionally, the expiration time ("exp"), not before time ("nbf"),
# and issued at ("iat") claims are validated if present.
#
# Note that this is a non-standard login type and client support is
# expected to be non-existant.
#
@ -1839,6 +1842,24 @@ sso:
#
#algorithm: "provided-by-your-issuer"
# The issuer to validate the "iss" claim against.
#
# Optional, if provided the "iss" claim will be required and
# validated for all JSON web tokens.
#
#issuer: "provided-by-your-issuer"
# A list of audiences to validate the "aud" claim against.
#
# Optional, if provided the "aud" claim will be required and
# validated for all JSON web tokens.
#
# Note that if the "aud" claim is included in a JSON web token then
# validation will fail without configuring audiences.
#
#audiences:
# - "provided-by-your-issuer"
password_config:
# Uncomment to disable password login

View file

@ -32,6 +32,11 @@ class JWTConfig(Config):
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
# The issuer and audiences are optional, if provided, it is asserted
# that the claims exist on the JWT.
self.jwt_issuer = jwt_config.get("issuer")
self.jwt_audiences = jwt_config.get("audiences")
try:
import jwt
@ -42,6 +47,8 @@ class JWTConfig(Config):
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
self.jwt_issuer = None
self.jwt_audiences = None
def generate_config_section(self, **kwargs):
return """\
@ -52,6 +59,9 @@ class JWTConfig(Config):
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
# Additionally, the expiration time ("exp"), not before time ("nbf"),
# and issued at ("iat") claims are validated if present.
#
# Note that this is a non-standard login type and client support is
# expected to be non-existant.
#
@ -78,4 +88,22 @@ class JWTConfig(Config):
# Required if 'enabled' is true.
#
#algorithm: "provided-by-your-issuer"
# The issuer to validate the "iss" claim against.
#
# Optional, if provided the "iss" claim will be required and
# validated for all JSON web tokens.
#
#issuer: "provided-by-your-issuer"
# A list of audiences to validate the "aud" claim against.
#
# Optional, if provided the "aud" claim will be required and
# validated for all JSON web tokens.
#
# Note that if the "aud" claim is included in a JSON web token then
# validation will fail without configuring audiences.
#
#audiences:
# - "provided-by-your-issuer"
"""

View file

@ -89,12 +89,19 @@ class LoginRestServlet(RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
self.hs = hs
# JWT configuration variables.
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.jwt_issuer = hs.config.jwt_issuer
self.jwt_audiences = hs.config.jwt_audiences
# SSO configuration.
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@ -368,16 +375,22 @@ class LoginRestServlet(RestServlet):
)
import jwt
from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(
token, self.jwt_secret, algorithms=[self.jwt_algorithm]
token,
self.jwt_secret,
algorithms=[self.jwt_algorithm],
issuer=self.jwt_issuer,
audience=self.jwt_audiences,
)
except jwt.PyJWTError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
401,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.UNAUTHORIZED,
)
except jwt.ExpiredSignatureError:
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload.get("sub", None)
if user is None:

View file

@ -514,16 +514,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
]
jwt_secret = "secret"
jwt_algorithm = "HS256"
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
self.hs.config.jwt_algorithm = "HS256"
self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, token, secret=jwt_secret):
return jwt.encode(token, secret, "HS256").decode("ascii")
return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
def jwt_login(self, *args):
params = json.dumps(
@ -548,20 +549,28 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
)
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "JWT expired")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Signature has expired"
)
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: The token is not yet valid (nbf)",
)
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
@ -569,6 +578,88 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@override_config(
{
"jwt_config": {
"jwt_enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
"issuer": "test-issuer",
}
}
)
def test_login_iss(self):
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid issuer"
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "iss" claim',
)
def test_login_iss_no_config(self):
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config(
{
"jwt_config": {
"jwt_enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
"audiences": ["test-audience"],
}
}
)
def test_login_aud(self):
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "aud" claim',
)
def test_login_aud_no_config(self):
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
def test_login_no_token(self):
params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
@ -658,4 +749,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
)