mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 22:13:29 +01:00
Move token generation to auth handler
I prefer the auth handler to worry about all auth, and register to call into it as needed, than to smatter auth logic between the two.
This commit is contained in:
parent
ade5342752
commit
617501dd2a
3 changed files with 38 additions and 31 deletions
|
@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
import pymacaroons
|
||||||
import simplejson
|
import simplejson
|
||||||
|
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
|
@ -284,12 +285,9 @@ class AuthHandler(BaseHandler):
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
yield self._check_password(user_id, password)
|
yield self._check_password(user_id, password)
|
||||||
|
|
||||||
reg_handler = self.hs.get_handlers().registration_handler
|
|
||||||
access_token = reg_handler.generate_token(user_id)
|
|
||||||
logger.info("Logging in user %s", user_id)
|
logger.info("Logging in user %s", user_id)
|
||||||
yield self.store.add_access_token_to_user(user_id, access_token)
|
token = yield self.issue_access_token(user_id)
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(token)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password(self, user_id, password):
|
def _check_password(self, user_id, password):
|
||||||
|
@ -304,6 +302,27 @@ class AuthHandler(BaseHandler):
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def issue_access_token(self, user_id):
|
||||||
|
reg_handler = self.hs.get_handlers().registration_handler
|
||||||
|
access_token = reg_handler.generate_access_token(user_id)
|
||||||
|
yield self.store.add_access_token_to_user(user_id, access_token)
|
||||||
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
|
def generate_access_token(self, user_id):
|
||||||
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
location = self.hs.config.server_name,
|
||||||
|
identifier = "key",
|
||||||
|
key = self.hs.config.macaroon_secret_key)
|
||||||
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
now = self.hs.get_clock().time_msec()
|
||||||
|
expiry = now + (60 * 60 * 1000)
|
||||||
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
|
||||||
|
return macaroon.serialize()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_password(self, user_id, newpassword):
|
def set_password(self, user_id, newpassword):
|
||||||
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
|
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
|
||||||
|
|
|
@ -27,7 +27,6 @@ from synapse.http.client import CaptchaServerHttpClient
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import logging
|
import logging
|
||||||
import pymacaroons
|
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
400, "Invalid user localpart for this application service.",
|
400, "Invalid user localpart for this application service.",
|
||||||
errcode=Codes.EXCLUSIVE
|
errcode=Codes.EXCLUSIVE
|
||||||
)
|
)
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
try:
|
try:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -273,20 +272,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
errcode=Codes.EXCLUSIVE
|
errcode=Codes.EXCLUSIVE
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_token(self, user_id):
|
|
||||||
macaroon = pymacaroons.Macaroon(
|
|
||||||
location = self.hs.config.server_name,
|
|
||||||
identifier = "key",
|
|
||||||
key = self.hs.config.macaroon_secret_key)
|
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
|
||||||
macaroon.add_first_party_caveat("type = access")
|
|
||||||
now = self.hs.get_clock().time_msec()
|
|
||||||
expiry = now + (60 * 60 * 1000)
|
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
|
||||||
|
|
||||||
return macaroon.serialize()
|
|
||||||
|
|
||||||
def _generate_user_id(self):
|
def _generate_user_id(self):
|
||||||
return "-" + stringutils.random_string(18)
|
return "-" + stringutils.random_string(18)
|
||||||
|
|
||||||
|
@ -329,3 +314,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
|
def auth_handler(self):
|
||||||
|
return self.hs.get_handlers().auth_handler
|
||||||
|
|
|
@ -16,27 +16,27 @@
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.auth import AuthHandler
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
||||||
class RegisterHandlers(object):
|
class AuthHandlers(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.registration_handler = RegistrationHandler(hs)
|
self.auth_handler = AuthHandler(hs)
|
||||||
|
|
||||||
|
|
||||||
class RegisterTestCase(unittest.TestCase):
|
class AuthTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.hs = yield setup_test_homeserver(handlers=None)
|
self.hs = yield setup_test_homeserver(handlers=None)
|
||||||
self.hs.handlers = RegisterHandlers(self.hs)
|
self.hs.handlers = AuthHandlers(self.hs)
|
||||||
|
|
||||||
def test_token_is_a_macaroon(self):
|
def test_token_is_a_macaroon(self):
|
||||||
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
||||||
|
|
||||||
token = self.hs.handlers.registration_handler.generate_token("some_user")
|
token = self.hs.handlers.auth_handler.generate_access_token("some_user")
|
||||||
# Check that we can parse the thing with pymacaroons
|
# Check that we can parse the thing with pymacaroons
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
# The most basic of sanity checks
|
# The most basic of sanity checks
|
||||||
|
@ -47,7 +47,7 @@ class RegisterTestCase(unittest.TestCase):
|
||||||
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
||||||
self.hs.clock.now = 5000
|
self.hs.clock.now = 5000
|
||||||
|
|
||||||
token = self.hs.handlers.registration_handler.generate_token("a_user")
|
token = self.hs.handlers.auth_handler.generate_access_token("a_user")
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
def verify_gen(caveat):
|
def verify_gen(caveat):
|
Loading…
Reference in a new issue