mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 05:13:50 +01:00
Fix a bug in automatic user creation with m.login.jwt. (#7585)
This commit is contained in:
parent
33c39ab93c
commit
fe434cd3c9
3 changed files with 162 additions and 7 deletions
1
changelog.d/7585.bugfix
Normal file
1
changelog.d/7585.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug in automatic user creation during first time login with `m.login.jwt`. Regression in v1.6.0. Contributed by @olof.
|
|
@ -299,7 +299,7 @@ class LoginRestServlet(RestServlet):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _complete_login(
|
async def _complete_login(
|
||||||
self, user_id, login_submission, callback=None, create_non_existant_users=False
|
self, user_id, login_submission, callback=None, create_non_existent_users=False
|
||||||
):
|
):
|
||||||
"""Called when we've successfully authed the user and now need to
|
"""Called when we've successfully authed the user and now need to
|
||||||
actually login them in (e.g. create devices). This gets called on
|
actually login them in (e.g. create devices). This gets called on
|
||||||
|
@ -312,7 +312,7 @@ class LoginRestServlet(RestServlet):
|
||||||
user_id (str): ID of the user to register.
|
user_id (str): ID of the user to register.
|
||||||
login_submission (dict): Dictionary of login information.
|
login_submission (dict): Dictionary of login information.
|
||||||
callback (func|None): Callback function to run after registration.
|
callback (func|None): Callback function to run after registration.
|
||||||
create_non_existant_users (bool): Whether to create the user if
|
create_non_existent_users (bool): Whether to create the user if
|
||||||
they don't exist. Defaults to False.
|
they don't exist. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -331,12 +331,13 @@ class LoginRestServlet(RestServlet):
|
||||||
update=True,
|
update=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if create_non_existant_users:
|
if create_non_existent_users:
|
||||||
user_id = await self.auth_handler.check_user_exists(user_id)
|
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
||||||
if not user_id:
|
if not canonical_uid:
|
||||||
user_id = await self.registration_handler.register_user(
|
canonical_uid = await self.registration_handler.register_user(
|
||||||
localpart=UserID.from_string(user_id).localpart
|
localpart=UserID.from_string(user_id).localpart
|
||||||
)
|
)
|
||||||
|
user_id = canonical_uid
|
||||||
|
|
||||||
device_id = login_submission.get("device_id")
|
device_id = login_submission.get("device_id")
|
||||||
initial_display_name = login_submission.get("initial_device_display_name")
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
|
@ -391,7 +392,7 @@ class LoginRestServlet(RestServlet):
|
||||||
|
|
||||||
user_id = UserID(user, self.hs.hostname).to_string()
|
user_id = UserID(user, self.hs.hostname).to_string()
|
||||||
result = await self._complete_login(
|
result = await self._complete_login(
|
||||||
user_id, login_submission, create_non_existant_users=True
|
user_id, login_submission, create_non_existent_users=True
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.rest.client.v1 import login, logout
|
from synapse.rest.client.v1 import login, logout
|
||||||
from synapse.rest.client.v2_alpha import devices
|
from synapse.rest.client.v2_alpha import devices
|
||||||
|
@ -473,3 +476,153 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||||
# Because the user is deactivated they are served an error template.
|
# Because the user is deactivated they are served an error template.
|
||||||
self.assertEqual(channel.code, 403)
|
self.assertEqual(channel.code, 403)
|
||||||
self.assertIn(b"SSO account deactivated", channel.result["body"])
|
self.assertIn(b"SSO account deactivated", channel.result["body"])
|
||||||
|
|
||||||
|
|
||||||
|
class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
jwt_secret = "secret"
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
self.hs = self.setup_test_homeserver()
|
||||||
|
self.hs.config.jwt_enabled = True
|
||||||
|
self.hs.config.jwt_secret = self.jwt_secret
|
||||||
|
self.hs.config.jwt_algorithm = "HS256"
|
||||||
|
return self.hs
|
||||||
|
|
||||||
|
def jwt_encode(self, token, secret=jwt_secret):
|
||||||
|
return jwt.encode(token, secret, "HS256").decode("ascii")
|
||||||
|
|
||||||
|
def jwt_login(self, *args):
|
||||||
|
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
|
||||||
|
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
|
self.render(request)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
def test_login_jwt_valid_registered(self):
|
||||||
|
self.register_user("kermit", "monkey")
|
||||||
|
channel = self.jwt_login({"sub": "kermit"})
|
||||||
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
|
def test_login_jwt_valid_unregistered(self):
|
||||||
|
channel = self.jwt_login({"sub": "frog"})
|
||||||
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||||
|
|
||||||
|
def test_login_jwt_invalid_signature(self):
|
||||||
|
channel = self.jwt_login({"sub": "frog"}, "notsecret")
|
||||||
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||||
|
|
||||||
|
def test_login_jwt_expired(self):
|
||||||
|
channel = self.jwt_login({"sub": "frog", "exp": 864000})
|
||||||
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(channel.json_body["error"], "JWT expired")
|
||||||
|
|
||||||
|
def test_login_jwt_not_before(self):
|
||||||
|
now = int(time.time())
|
||||||
|
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
|
||||||
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||||
|
|
||||||
|
def test_login_no_sub(self):
|
||||||
|
channel = self.jwt_login({"username": "root"})
|
||||||
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||||
|
|
||||||
|
def test_login_no_token(self):
|
||||||
|
params = json.dumps({"type": "m.login.jwt"})
|
||||||
|
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
|
||||||
|
|
||||||
|
|
||||||
|
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
|
||||||
|
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
|
||||||
|
# signed by the private key.
|
||||||
|
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
# This key's pubkey is used as the jwt_secret setting of synapse. Valid
|
||||||
|
# tokens are signed by this and validated using the pubkey. It is generated
|
||||||
|
# with `openssl genrsa 512` (not a secure way to generate real keys, but
|
||||||
|
# good enough for tests!)
|
||||||
|
jwt_privatekey = "\n".join(
|
||||||
|
[
|
||||||
|
"-----BEGIN RSA PRIVATE KEY-----",
|
||||||
|
"MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
|
||||||
|
"492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
|
||||||
|
"yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
|
||||||
|
"kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
|
||||||
|
"TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
|
||||||
|
"ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
|
||||||
|
"tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
|
||||||
|
"-----END RSA PRIVATE KEY-----",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generated with `openssl rsa -in foo.key -pubout`, with the the above
|
||||||
|
# private key placed in foo.key (jwt_privatekey).
|
||||||
|
jwt_pubkey = "\n".join(
|
||||||
|
[
|
||||||
|
"-----BEGIN PUBLIC KEY-----",
|
||||||
|
"MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
|
||||||
|
"TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
|
||||||
|
"-----END PUBLIC KEY-----",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# This key is used to sign tokens that shouldn't be accepted by synapse.
|
||||||
|
# Generated just like jwt_privatekey.
|
||||||
|
bad_privatekey = "\n".join(
|
||||||
|
[
|
||||||
|
"-----BEGIN RSA PRIVATE KEY-----",
|
||||||
|
"MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
|
||||||
|
"gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
|
||||||
|
"R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
|
||||||
|
"uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
|
||||||
|
"eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
|
||||||
|
"iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
|
||||||
|
"KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
|
||||||
|
"-----END RSA PRIVATE KEY-----",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
self.hs = self.setup_test_homeserver()
|
||||||
|
self.hs.config.jwt_enabled = True
|
||||||
|
self.hs.config.jwt_secret = self.jwt_pubkey
|
||||||
|
self.hs.config.jwt_algorithm = "RS256"
|
||||||
|
return self.hs
|
||||||
|
|
||||||
|
def jwt_encode(self, token, secret=jwt_privatekey):
|
||||||
|
return jwt.encode(token, secret, "RS256").decode("ascii")
|
||||||
|
|
||||||
|
def jwt_login(self, *args):
|
||||||
|
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
|
||||||
|
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
|
self.render(request)
|
||||||
|
return channel
|
||||||
|
|
||||||
|
def test_login_jwt_valid(self):
|
||||||
|
channel = self.jwt_login({"sub": "kermit"})
|
||||||
|
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
||||||
|
|
||||||
|
def test_login_jwt_invalid_signature(self):
|
||||||
|
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
|
||||||
|
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
|
||||||
|
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||||
|
|
Loading…
Reference in a new issue