Replace pyjwt with authlib in org.matrix.login.jwt (#13011)

This commit is contained in:
Hannes Lerchl 2022-06-15 18:45:16 +02:00 committed by GitHub
parent e12ff697a4
commit 7d99414edf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 100 additions and 57 deletions

1
changelog.d/13011.misc Normal file
View file

@ -0,0 +1 @@
Replaced usage of PyJWT with methods from Authlib in `org.matrix.login.jwt`. Contributed by Hannes Lerchl.

View file

@ -37,19 +37,19 @@ As with other login types, there are additional fields (e.g. `device_id` and
## Preparing Synapse ## Preparing Synapse
The JSON Web Token integration in Synapse uses the The JSON Web Token integration in Synapse uses the
[`PyJWT`](https://pypi.org/project/pyjwt/) library, which must be installed [`Authlib`](https://docs.authlib.org/en/latest/index.html) library, which must be installed
as follows: as follows:
* The relevant libraries are included in the Docker images and Debian packages * The relevant libraries are included in the Docker images and Debian packages
provided by `matrix.org` so no further action is needed. provided by `matrix.org` so no further action is needed.
* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip * If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
install synapse[pyjwt]` to install the necessary dependencies. install synapse[jwt]` to install the necessary dependencies.
* For other installation mechanisms, see the documentation provided by the * For other installation mechanisms, see the documentation provided by the
maintainer. maintainer.
To enable the JSON web token integration, you should then add an `jwt_config` section To enable the JSON web token integration, you should then add a `jwt_config` section
to your configuration file (or uncomment the `enabled: true` line in the to your configuration file (or uncomment the `enabled: true` line in the
existing section). See [sample_config.yaml](./sample_config.yaml) for some existing section). See [sample_config.yaml](./sample_config.yaml) for some
sample settings. sample settings.
@ -57,7 +57,7 @@ sample settings.
## How to test JWT as a developer ## How to test JWT as a developer
Although JSON Web Tokens are typically generated from an external server, the Although JSON Web Tokens are typically generated from an external server, the
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly. example below uses a locally generated JWT.
1. Configure Synapse with JWT logins, note that this example uses a pre-shared 1. Configure Synapse with JWT logins, note that this example uses a pre-shared
secret and an algorithm of HS256: secret and an algorithm of HS256:
@ -70,10 +70,21 @@ examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
``` ```
2. Generate a JSON web token: 2. Generate a JSON web token:
```bash You can use the following short Python snippet to generate a JWT
$ pyjwt --key=my-secret-token --alg=HS256 encode sub=test-user protected by an HMAC.
eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc Take care that the `secret` and the algorithm given in the `header` match
the entries from `jwt_config` above.
```python
from authlib.jose import jwt
header = {"alg": "HS256"}
payload = {"sub": "user1", "aud": ["audience"]}
secret = "my-secret-token"
result = jwt.encode(header, payload, secret)
print(result.decode("ascii"))
``` ```
3. Query for the login types and ensure `org.matrix.login.jwt` is there: 3. Query for the login types and ensure `org.matrix.login.jwt` is there:
```bash ```bash

View file

@ -2946,8 +2946,10 @@ Additional sub-options for this setting include:
tokens. Defaults to false. tokens. Defaults to false.
* `secret`: This is either the private shared secret or the public key used to * `secret`: This is either the private shared secret or the public key used to
decode the contents of the JSON web token. Required if `enabled` is set to true. decode the contents of the JSON web token. Required if `enabled` is set to true.
* `algorithm`: The algorithm used to sign the JSON web token. Supported algorithms are listed at * `algorithm`: The algorithm used to sign (or HMAC) the JSON web token.
https://pyjwt.readthedocs.io/en/latest/algorithms.html Required if `enabled` is set to true. Supported algorithms are listed
[here (section JWS)](https://docs.authlib.org/en/latest/specs/rfc7518.html).
Required if `enabled` is set to true.
* `subject_claim`: Name of the claim containing a unique identifier for the user. * `subject_claim`: Name of the claim containing a unique identifier for the user.
Optional, defaults to `sub`. Optional, defaults to `sub`.
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the * `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the

8
poetry.lock generated
View file

@ -815,7 +815,7 @@ python-versions = ">=3.5"
name = "pyjwt" name = "pyjwt"
version = "2.4.0" version = "2.4.0"
description = "JSON Web Token implementation in Python" description = "JSON Web Token implementation in Python"
category = "main" category = "dev"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
@ -1546,9 +1546,9 @@ docs = ["sphinx", "repoze.sphinx.autointerface"]
test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"] test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"]
[extras] [extras]
all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "pyjwt", "txredisapi", "hiredis", "Pympler"] all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler"]
cache_memory = ["Pympler"] cache_memory = ["Pympler"]
jwt = ["pyjwt"] jwt = ["authlib"]
matrix-synapse-ldap3 = ["matrix-synapse-ldap3"] matrix-synapse-ldap3 = ["matrix-synapse-ldap3"]
oidc = ["authlib"] oidc = ["authlib"]
opentracing = ["jaeger-client", "opentracing"] opentracing = ["jaeger-client", "opentracing"]
@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.7.1" python-versions = "^3.7.1"
content-hash = "37bd4bccfdb5a869635f2135a85bea4a0729af7375a27de153b4fd9a4aebc195" content-hash = "73882e279e0379482f2fc7414cb71addfd408ca48ad508ff8a02b0cb544762af"
[metadata.files] [metadata.files]
attrs = [ attrs = [

View file

@ -175,7 +175,6 @@ lxml = { version = ">=4.2.0", optional = true }
sentry-sdk = { version = ">=0.7.2", optional = true } sentry-sdk = { version = ">=0.7.2", optional = true }
opentracing = { version = ">=2.2.0", optional = true } opentracing = { version = ">=2.2.0", optional = true }
jaeger-client = { version = ">=4.0.0", optional = true } jaeger-client = { version = ">=4.0.0", optional = true }
pyjwt = { version = ">=1.6.4", optional = true }
txredisapi = { version = ">=1.4.7", optional = true } txredisapi = { version = ">=1.4.7", optional = true }
hiredis = { version = "*", optional = true } hiredis = { version = "*", optional = true }
Pympler = { version = "*", optional = true } Pympler = { version = "*", optional = true }
@ -196,7 +195,7 @@ systemd = ["systemd-python"]
url_preview = ["lxml"] url_preview = ["lxml"]
sentry = ["sentry-sdk"] sentry = ["sentry-sdk"]
opentracing = ["jaeger-client", "opentracing"] opentracing = ["jaeger-client", "opentracing"]
jwt = ["pyjwt"] jwt = ["authlib"]
# hiredis is not a *strict* dependency, but it makes things much faster. # hiredis is not a *strict* dependency, but it makes things much faster.
# (if it is not installed, we fall back to slow code.) # (if it is not installed, we fall back to slow code.)
redis = ["txredisapi", "hiredis"] redis = ["txredisapi", "hiredis"]
@ -222,7 +221,7 @@ all = [
"psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat",
# saml2 # saml2
"pysaml2", "pysaml2",
# oidc # oidc and jwt
"authlib", "authlib",
# url_preview # url_preview
"lxml", "lxml",
@ -230,8 +229,6 @@ all = [
"sentry-sdk", "sentry-sdk",
# opentracing # opentracing
"jaeger-client", "opentracing", "jaeger-client", "opentracing",
# jwt
"pyjwt",
# redis # redis
"txredisapi", "hiredis", "txredisapi", "hiredis",
# cache_memory # cache_memory

View file

@ -18,10 +18,10 @@ from synapse.types import JsonDict
from ._base import Config, ConfigError from ._base import Config, ConfigError
MISSING_JWT = """Missing jwt library. This is required for jwt login. MISSING_AUTHLIB = """Missing authlib library. This is required for jwt login.
Install by running: Install by running:
pip install pyjwt pip install synapse[jwt]
""" """
@ -43,11 +43,11 @@ class JWTConfig(Config):
self.jwt_audiences = jwt_config.get("audiences") self.jwt_audiences = jwt_config.get("audiences")
try: try:
import jwt from authlib.jose import JsonWebToken
jwt # To stop unused lint. JsonWebToken # To stop unused lint.
except ImportError: except ImportError:
raise ConfigError(MISSING_JWT) raise ConfigError(MISSING_AUTHLIB)
else: else:
self.jwt_enabled = False self.jwt_enabled = False
self.jwt_secret = None self.jwt_secret = None

View file

@ -420,17 +420,31 @@ class LoginRestServlet(RestServlet):
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
) )
import jwt from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
jwt = JsonWebToken([self.jwt_algorithm])
claim_options = {}
if self.jwt_issuer is not None:
claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
if self.jwt_audiences is not None:
claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
try: try:
payload = jwt.decode( claims = jwt.decode(
token, token,
self.jwt_secret, key=self.jwt_secret,
algorithms=[self.jwt_algorithm], claims_cls=JWTClaims,
issuer=self.jwt_issuer, claims_options=claim_options,
audience=self.jwt_audiences,
) )
except jwt.PyJWTError as e: except BadSignatureError:
# We handle this case separately to provide a better error message
raise LoginError(
403,
"JWT validation failed: Signature verification failed",
errcode=Codes.FORBIDDEN,
)
except JoseError as e:
# A JWT error occurred, return some info back to the client. # A JWT error occurred, return some info back to the client.
raise LoginError( raise LoginError(
403, 403,
@ -438,7 +452,23 @@ class LoginRestServlet(RestServlet):
errcode=Codes.FORBIDDEN, errcode=Codes.FORBIDDEN,
) )
user = payload.get(self.jwt_subject_claim, None) try:
claims.validate(leeway=120) # allows 2 min of clock skew
# Enforce the old behavior which is rolled out in productive
# servers: if the JWT contains an 'aud' claim but none is
# configured, the login attempt will fail
if claims.get("aud") is not None:
if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
raise InvalidClaimError("aud")
except JoseError as e:
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
user = claims.get(self.jwt_subject_claim, None)
if user is None: if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

View file

@ -14,7 +14,7 @@
import json import json
import time import time
import urllib.parse import urllib.parse
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from urllib.parse import urlencode from urllib.parse import urlencode
@ -41,7 +41,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless from tests.unittest import HomeserverTestCase, override_config, skip_unless
try: try:
import jwt from authlib.jose import jwk, jwt
HAS_JWT = True HAS_JWT = True
except ImportError: except ImportError:
@ -841,7 +841,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertIn(b"SSO account deactivated", channel.result["body"]) self.assertIn(b"SSO account deactivated", channel.result["body"])
@skip_unless(HAS_JWT, "requires jwt") @skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase): class JWTTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
@ -866,11 +866,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
return config return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. header = {"alg": self.jwt_algorithm}
result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) result: bytes = jwt.encode(header, payload, secret)
if isinstance(result, bytes):
return result.decode("ascii") return result.decode("ascii")
return result
def jwt_login(self, *args: Any) -> FakeChannel: def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@ -902,7 +900,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Signature has expired" channel.json_body["error"],
"JWT validation failed: expired_token: The token is expired",
) )
def test_login_jwt_not_before(self) -> None: def test_login_jwt_not_before(self) -> None:
@ -912,7 +911,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
"JWT validation failed: The token is not yet valid (nbf)", "JWT validation failed: invalid_token: The token is not valid yet",
) )
def test_login_no_sub(self) -> None: def test_login_no_sub(self) -> None:
@ -934,7 +933,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid issuer" channel.json_body["error"],
'JWT validation failed: invalid_claim: Invalid claim "iss"',
) )
# Not providing an issuer. # Not providing an issuer.
@ -943,7 +943,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
'JWT validation failed: Token is missing the "iss" claim', 'JWT validation failed: missing_claim: Missing "iss" claim',
) )
def test_login_iss_no_config(self) -> None: def test_login_iss_no_config(self) -> None:
@ -965,7 +965,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience" channel.json_body["error"],
'JWT validation failed: invalid_claim: Invalid claim "aud"',
) )
# Not providing an audience. # Not providing an audience.
@ -974,7 +975,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
'JWT validation failed: Token is missing the "aud" claim', 'JWT validation failed: missing_claim: Missing "aud" claim',
) )
def test_login_aud_no_config(self) -> None: def test_login_aud_no_config(self) -> None:
@ -983,7 +984,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience" channel.json_body["error"],
'JWT validation failed: invalid_claim: Invalid claim "aud"',
) )
def test_login_default_sub(self) -> None: def test_login_default_sub(self) -> None:
@ -1010,7 +1012,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key. # signed by the private key.
@skip_unless(HAS_JWT, "requires jwt") @skip_unless(HAS_JWT, "requires authlib")
class JWTPubKeyTestCase(unittest.HomeserverTestCase): class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
login.register_servlets, login.register_servlets,
@ -1071,11 +1073,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return config return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. header = {"alg": "RS256"}
result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
if isinstance(result, bytes): secret = jwk.dumps(secret, kty="RSA")
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii") return result.decode("ascii")
return result
def jwt_login(self, *args: Any) -> FakeChannel: def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}