0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 13:13:50 +01:00

/tokenrefresh POST endpoint

This allows refresh tokens to be exchanged for (access_token,
refresh_token).

It also starts issuing them on login, though no clients currently
interpret them.
This commit is contained in:
Daniel Wagner-Hall 2015-08-20 16:21:35 +01:00
parent 13a6517d89
commit cecbd636e9
9 changed files with 232 additions and 8 deletions

View file

@ -279,15 +279,18 @@ class AuthHandler(BaseHandler):
user_id (str): User ID user_id (str): User ID
password (str): Password password (str): Password
Returns: Returns:
A tuple of:
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
yield self._check_password(user_id, password) yield self._check_password(user_id, password)
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
token = yield self.issue_access_token(user_id) access_token = yield self.issue_access_token(user_id)
defer.returnValue(token) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
@ -304,11 +307,16 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id): def issue_access_token(self, user_id):
reg_handler = self.hs.get_handlers().registration_handler access_token = self.generate_access_token(user_id)
access_token = reg_handler.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token) yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id): def generate_access_token(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location = self.hs.config.server_name, location = self.hs.config.server_name,
@ -323,6 +331,23 @@ class AuthHandler(BaseHandler):
return macaroon.serialize() return macaroon.serialize()
def generate_refresh_token(self, user_id):
m = self._generate_base_macaroon(user_id)
m.add_first_party_caveat("type = refresh")
# Important to add a nonce, because otherwise every refresh token for a
# user will be the same.
m.add_first_party_caveat("nonce = %s" % stringutils.random_string_with_symbols(16))
return m.serialize()
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location = self.hs.config.server_name,
identifier = "key",
key = self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())

View file

@ -78,13 +78,15 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission["user"] = UserID.create( login_submission["user"] = UserID.create(
login_submission["user"], self.hs.hostname).to_string() login_submission["user"], self.hs.hostname).to_string()
token = yield self.handlers.auth_handler.login_with_password( auth_handler = self.handlers.auth_handler
access_token, refresh_token = yield auth_handler.login_with_password(
user_id=login_submission["user"], user_id=login_submission["user"],
password=login_submission["password"]) password=login_submission["password"])
result = { result = {
"user_id": login_submission["user"], # may have changed "user_id": login_submission["user"], # may have changed
"access_token": token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }

View file

@ -21,6 +21,7 @@ from . import (
auth, auth,
receipts, receipts,
keys, keys,
tokenrefresh,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource):
auth.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)

View file

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request
class TokenRefreshRestServlet(RestServlet):
"""
Exchanges refresh tokens for a pair of an access token and a new refresh
token.
"""
PATTERN = client_v2_pattern("/tokenrefresh")
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_dict_from_request(request)
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id)
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}))
except KeyError:
raise SynapseError(400, "Missing required key 'refresh_token'.")
except StoreError:
raise AuthError(403, "Did not recognize refresh token")
def register_servlets(hs, http_server):
TokenRefreshRestServlet(hs).register(http_server)

View file

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 22 SCHEMA_VERSION = 23
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -181,6 +181,7 @@ class SQLBaseStore(object):
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)

View file

@ -50,6 +50,28 @@ class RegistrationStore(SQLBaseStore):
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
@defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
Raises:
StoreError if there was a problem adding this.
"""
next_id = yield self._refresh_tokens_id_gen.get_next()
yield self._simple_insert(
"refresh_tokens",
{
"id": next_id,
"user_id": user_id,
"token": token
},
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, user_id, token, password_hash): def register(self, user_id, token, password_hash):
"""Attempts to register an account. """Attempts to register an account.
@ -152,6 +174,46 @@ class RegistrationStore(SQLBaseStore):
token token
) )
def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new access token and refresh token.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
tuple of (user_id, refresh_token)
Raises:
StoreError if no user was found with that refresh token.
"""
return self.runInteraction(
"exchange_refresh_token",
self._exchange_refresh_token,
refresh_token,
token_generator
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
new_token = token_generator(user_id)
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
return user_id, new_token
@defer.inlineCallbacks @defer.inlineCallbacks
def is_server_admin(self, user): def is_server_admin(self, user):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(

View file

@ -0,0 +1,21 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS refresh_tokens(
id INTEGER PRIMARY KEY AUTOINCREMENT,
token TEXT NOT NULL,
user_id TEXT NOT NULL,
UNIQUE (token)
);

View file

@ -17,7 +17,9 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage.registration import RegistrationStore from synapse.storage.registration import RegistrationStore
from synapse.util import stringutils
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver() hs = yield setup_test_homeserver()
self.db_pool = hs.get_db_pool()
self.store = RegistrationStore(hs) self.store = RegistrationStore(hs)
@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertTrue("token_id" in result) self.assertTrue("token_id" in result)
@defer.inlineCallbacks
def test_exchange_refresh_token_valid(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
(uid, last_token,))
(found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
last_token, generator.generate)
self.assertEqual(uid, found_user_id)
rows = yield self.db_pool.runQuery(
"SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
self.assertEqual([(refresh_token,)], rows)
# We issued token 1, then exchanged it for token 2
expected_refresh_token = u"%s-%d" % (uid, 2,)
self.assertEqual(expected_refresh_token, refresh_token)
@defer.inlineCallbacks
def test_exchange_refresh_token_none(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks
def test_exchange_refresh_token_invalid(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
wrong_token = "%s-wrong" % (last_token,)
self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
(uid, wrong_token,))
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
class TokenGenerator:
def __init__(self):
self._last_issued_token = 0
def generate(self, user_id):
self._last_issued_token += 1
return u"%s-%d" % (user_id, self._last_issued_token,)