mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-21 01:31:51 +01:00
Merge pull request #240 from matrix-org/refresh
/tokenrefresh POST endpoint
This commit is contained in:
commit
b1e35eabf2
19 changed files with 303 additions and 76 deletions
|
@ -361,7 +361,7 @@ class Auth(object):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass # normal users won't have the user_id query parameter set.
|
pass # normal users won't have the user_id query parameter set.
|
||||||
|
|
||||||
user_info = yield self.get_user_by_token(access_token)
|
user_info = yield self.get_user_by_access_token(access_token)
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
device_id = user_info["device_id"]
|
device_id = user_info["device_id"]
|
||||||
token_id = user_info["token_id"]
|
token_id = user_info["token_id"]
|
||||||
|
@ -390,7 +390,7 @@ class Auth(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_user_by_token(self, token):
|
def get_user_by_access_token(self, token):
|
||||||
""" Get a registered user's ID.
|
""" Get a registered user's ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -401,7 +401,7 @@ class Auth(object):
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
ret = yield self.store.get_user_by_token(token)
|
ret = yield self.store.get_user_by_access_token(token)
|
||||||
if not ret:
|
if not ret:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
||||||
|
|
|
@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
import pymacaroons
|
||||||
import simplejson
|
import simplejson
|
||||||
|
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
|
@ -278,18 +279,18 @@ class AuthHandler(BaseHandler):
|
||||||
user_id (str): User ID
|
user_id (str): User ID
|
||||||
password (str): Password
|
password (str): Password
|
||||||
Returns:
|
Returns:
|
||||||
|
A tuple of:
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
|
The refresh token for the user's session.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem storing the token.
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
yield self._check_password(user_id, password)
|
yield self._check_password(user_id, password)
|
||||||
|
|
||||||
reg_handler = self.hs.get_handlers().registration_handler
|
|
||||||
access_token = reg_handler.generate_token(user_id)
|
|
||||||
logger.info("Logging in user %s", user_id)
|
logger.info("Logging in user %s", user_id)
|
||||||
yield self.store.add_access_token_to_user(user_id, access_token)
|
access_token = yield self.issue_access_token(user_id)
|
||||||
defer.returnValue(access_token)
|
refresh_token = yield self.issue_refresh_token(user_id)
|
||||||
|
defer.returnValue((access_token, refresh_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password(self, user_id, password):
|
def _check_password(self, user_id, password):
|
||||||
|
@ -304,6 +305,45 @@ class AuthHandler(BaseHandler):
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def issue_access_token(self, user_id):
|
||||||
|
access_token = self.generate_access_token(user_id)
|
||||||
|
yield self.store.add_access_token_to_user(user_id, access_token)
|
||||||
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def issue_refresh_token(self, user_id):
|
||||||
|
refresh_token = self.generate_refresh_token(user_id)
|
||||||
|
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
||||||
|
defer.returnValue(refresh_token)
|
||||||
|
|
||||||
|
def generate_access_token(self, user_id):
|
||||||
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
now = self.hs.get_clock().time_msec()
|
||||||
|
expiry = now + (60 * 60 * 1000)
|
||||||
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
return macaroon.serialize()
|
||||||
|
|
||||||
|
def generate_refresh_token(self, user_id):
|
||||||
|
m = self._generate_base_macaroon(user_id)
|
||||||
|
m.add_first_party_caveat("type = refresh")
|
||||||
|
# Important to add a nonce, because otherwise every refresh token for a
|
||||||
|
# user will be the same.
|
||||||
|
m.add_first_party_caveat("nonce = %s" % (
|
||||||
|
stringutils.random_string_with_symbols(16),
|
||||||
|
))
|
||||||
|
return m.serialize()
|
||||||
|
|
||||||
|
def _generate_base_macaroon(self, user_id):
|
||||||
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
location=self.hs.config.server_name,
|
||||||
|
identifier="key",
|
||||||
|
key=self.hs.config.macaroon_secret_key)
|
||||||
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
|
return macaroon
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_password(self, user_id, newpassword):
|
def set_password(self, user_id, newpassword):
|
||||||
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
|
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
|
||||||
|
|
|
@ -27,7 +27,6 @@ from synapse.http.client import CaptchaServerHttpClient
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import logging
|
import logging
|
||||||
import pymacaroons
|
|
||||||
import urllib
|
import urllib
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
400, "Invalid user localpart for this application service.",
|
400, "Invalid user localpart for this application service.",
|
||||||
errcode=Codes.EXCLUSIVE
|
errcode=Codes.EXCLUSIVE
|
||||||
)
|
)
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
token = self.generate_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
try:
|
try:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -273,20 +272,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
errcode=Codes.EXCLUSIVE
|
errcode=Codes.EXCLUSIVE
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_token(self, user_id):
|
|
||||||
macaroon = pymacaroons.Macaroon(
|
|
||||||
location=self.hs.config.server_name,
|
|
||||||
identifier="key",
|
|
||||||
key=self.hs.config.macaroon_secret_key)
|
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
|
||||||
macaroon.add_first_party_caveat("type = access")
|
|
||||||
now = self.hs.get_clock().time_msec()
|
|
||||||
expiry = now + (60 * 60 * 1000)
|
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
|
||||||
|
|
||||||
return macaroon.serialize()
|
|
||||||
|
|
||||||
def _generate_user_id(self):
|
def _generate_user_id(self):
|
||||||
return "-" + stringutils.random_string(18)
|
return "-" + stringutils.random_string(18)
|
||||||
|
|
||||||
|
@ -329,3 +314,6 @@ class RegistrationHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
|
def auth_handler(self):
|
||||||
|
return self.hs.get_handlers().auth_handler
|
||||||
|
|
|
@ -85,13 +85,15 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
user_id = UserID.create(
|
user_id = UserID.create(
|
||||||
user_id, self.hs.hostname).to_string()
|
user_id, self.hs.hostname).to_string()
|
||||||
|
|
||||||
token = yield self.handlers.auth_handler.login_with_password(
|
auth_handler = self.handlers.auth_handler
|
||||||
|
access_token, refresh_token = yield auth_handler.login_with_password(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password=login_submission["password"])
|
password=login_submission["password"])
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": login_submission["user"], # may have changed
|
||||||
"access_token": token,
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from . import (
|
||||||
auth,
|
auth,
|
||||||
receipts,
|
receipts,
|
||||||
keys,
|
keys,
|
||||||
|
tokenrefresh,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
|
@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource):
|
||||||
auth.register_servlets(hs, client_resource)
|
auth.register_servlets(hs, client_resource)
|
||||||
receipts.register_servlets(hs, client_resource)
|
receipts.register_servlets(hs, client_resource)
|
||||||
keys.register_servlets(hs, client_resource)
|
keys.register_servlets(hs, client_resource)
|
||||||
|
tokenrefresh.register_servlets(hs, client_resource)
|
||||||
|
|
56
synapse/rest/client/v2_alpha/tokenrefresh.py
Normal file
56
synapse/rest/client/v2_alpha/tokenrefresh.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||||
|
from synapse.http.servlet import RestServlet
|
||||||
|
|
||||||
|
from ._base import client_v2_pattern, parse_json_dict_from_request
|
||||||
|
|
||||||
|
|
||||||
|
class TokenRefreshRestServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
Exchanges refresh tokens for a pair of an access token and a new refresh
|
||||||
|
token.
|
||||||
|
"""
|
||||||
|
PATTERN = client_v2_pattern("/tokenrefresh")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(TokenRefreshRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
body = parse_json_dict_from_request(request)
|
||||||
|
try:
|
||||||
|
old_refresh_token = body["refresh_token"]
|
||||||
|
auth_handler = self.hs.get_handlers().auth_handler
|
||||||
|
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
||||||
|
old_refresh_token, auth_handler.generate_refresh_token)
|
||||||
|
new_access_token = yield auth_handler.issue_access_token(user_id)
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"access_token": new_access_token,
|
||||||
|
"refresh_token": new_refresh_token,
|
||||||
|
}))
|
||||||
|
except KeyError:
|
||||||
|
raise SynapseError(400, "Missing required key 'refresh_token'.")
|
||||||
|
except StoreError:
|
||||||
|
raise AuthError(403, "Did not recognize refresh token")
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
TokenRefreshRestServlet(hs).register(http_server)
|
|
@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 22
|
SCHEMA_VERSION = 23
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
|
@ -181,6 +181,7 @@ class SQLBaseStore(object):
|
||||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||||
|
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
|
||||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||||
|
|
|
@ -50,6 +50,28 @@ class RegistrationStore(SQLBaseStore):
|
||||||
desc="add_access_token_to_user",
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def add_refresh_token_to_user(self, user_id, token):
|
||||||
|
"""Adds a refresh token for the given user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The user ID.
|
||||||
|
token (str): The new refresh token to add.
|
||||||
|
Raises:
|
||||||
|
StoreError if there was a problem adding this.
|
||||||
|
"""
|
||||||
|
next_id = yield self._refresh_tokens_id_gen.get_next()
|
||||||
|
|
||||||
|
yield self._simple_insert(
|
||||||
|
"refresh_tokens",
|
||||||
|
{
|
||||||
|
"id": next_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"token": token
|
||||||
|
},
|
||||||
|
desc="add_refresh_token_to_user",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, user_id, token, password_hash):
|
def register(self, user_id, token, password_hash):
|
||||||
"""Attempts to register an account.
|
"""Attempts to register an account.
|
||||||
|
@ -132,10 +154,10 @@ class RegistrationStore(SQLBaseStore):
|
||||||
user_id
|
user_id
|
||||||
)
|
)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
self.get_user_by_token.invalidate((r,))
|
self.get_user_by_access_token.invalidate((r,))
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_user_by_token(self, token):
|
def get_user_by_access_token(self, token):
|
||||||
"""Get a user from the given access token.
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -147,11 +169,51 @@ class RegistrationStore(SQLBaseStore):
|
||||||
StoreError if no user was found.
|
StoreError if no user was found.
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_user_by_token",
|
"get_user_by_access_token",
|
||||||
self._query_for_auth,
|
self._query_for_auth,
|
||||||
token
|
token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def exchange_refresh_token(self, refresh_token, token_generator):
|
||||||
|
"""Exchange a refresh token for a new access token and refresh token.
|
||||||
|
|
||||||
|
Doing so invalidates the old refresh token - refresh tokens are single
|
||||||
|
use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The refresh token of a user.
|
||||||
|
token_generator (fn: str -> str): Function which, when given a
|
||||||
|
user ID, returns a unique refresh token for that user. This
|
||||||
|
function must never return the same value twice.
|
||||||
|
Returns:
|
||||||
|
tuple of (user_id, refresh_token)
|
||||||
|
Raises:
|
||||||
|
StoreError if no user was found with that refresh token.
|
||||||
|
"""
|
||||||
|
return self.runInteraction(
|
||||||
|
"exchange_refresh_token",
|
||||||
|
self._exchange_refresh_token,
|
||||||
|
refresh_token,
|
||||||
|
token_generator
|
||||||
|
)
|
||||||
|
|
||||||
|
def _exchange_refresh_token(self, txn, old_token, token_generator):
|
||||||
|
sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
|
||||||
|
txn.execute(sql, (old_token,))
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
if not rows:
|
||||||
|
raise StoreError(403, "Did not recognize refresh token")
|
||||||
|
user_id = rows[0]["user_id"]
|
||||||
|
|
||||||
|
# TODO(danielwh): Maybe perform a validation on the macaroon that
|
||||||
|
# macaroon.user_id == user_id.
|
||||||
|
|
||||||
|
new_token = token_generator(user_id)
|
||||||
|
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
|
||||||
|
txn.execute(sql, (new_token, old_token,))
|
||||||
|
|
||||||
|
return user_id, new_token
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def is_server_admin(self, user):
|
def is_server_admin(self, user):
|
||||||
res = yield self._simple_select_one_onecol(
|
res = yield self._simple_select_one_onecol(
|
||||||
|
|
21
synapse/storage/schema/delta/23/refresh_tokens.sql
Normal file
21
synapse/storage/schema/delta/23/refresh_tokens.sql
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS refresh_tokens(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
token TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
UNIQUE (token)
|
||||||
|
);
|
|
@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
"token_id": "ditto",
|
"token_id": "ditto",
|
||||||
"admin": False
|
"admin": False
|
||||||
}
|
}
|
||||||
self.store.get_user_by_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
|
@ -54,7 +54,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def test_get_user_by_req_user_bad_token(self):
|
def test_get_user_by_req_user_bad_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
self.store.get_user_by_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
|
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
"token_id": "ditto",
|
"token_id": "ditto",
|
||||||
"admin": False
|
"admin": False
|
||||||
}
|
}
|
||||||
self.store.get_user_by_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
|
@ -81,7 +81,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
def test_get_user_by_req_appservice_valid_token(self):
|
def test_get_user_by_req_appservice_valid_token(self):
|
||||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
|
@ -91,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_bad_token(self):
|
def test_get_user_by_req_appservice_bad_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
self.store.get_user_by_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
|
@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
def test_get_user_by_req_appservice_missing_token(self):
|
def test_get_user_by_req_appservice_missing_token(self):
|
||||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
|
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
||||||
app_service.is_interested_in_user = Mock(return_value=True)
|
app_service.is_interested_in_user = Mock(return_value=True)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
|
@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
||||||
app_service.is_interested_in_user = Mock(return_value=False)
|
app_service.is_interested_in_user = Mock(return_value=False)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
|
|
|
@ -16,27 +16,27 @@
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.auth import AuthHandler
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
||||||
class RegisterHandlers(object):
|
class AuthHandlers(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.registration_handler = RegistrationHandler(hs)
|
self.auth_handler = AuthHandler(hs)
|
||||||
|
|
||||||
|
|
||||||
class RegisterTestCase(unittest.TestCase):
|
class AuthTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.hs = yield setup_test_homeserver(handlers=None)
|
self.hs = yield setup_test_homeserver(handlers=None)
|
||||||
self.hs.handlers = RegisterHandlers(self.hs)
|
self.hs.handlers = AuthHandlers(self.hs)
|
||||||
|
|
||||||
def test_token_is_a_macaroon(self):
|
def test_token_is_a_macaroon(self):
|
||||||
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
||||||
|
|
||||||
token = self.hs.handlers.registration_handler.generate_token("some_user")
|
token = self.hs.handlers.auth_handler.generate_access_token("some_user")
|
||||||
# Check that we can parse the thing with pymacaroons
|
# Check that we can parse the thing with pymacaroons
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
# The most basic of sanity checks
|
# The most basic of sanity checks
|
||||||
|
@ -47,7 +47,7 @@ class RegisterTestCase(unittest.TestCase):
|
||||||
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
||||||
self.hs.clock.now = 5000
|
self.hs.clock.now = 5000
|
||||||
|
|
||||||
token = self.hs.handlers.registration_handler.generate_token("a_user")
|
token = self.hs.handlers.auth_handler.generate_access_token("a_user")
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
def verify_gen(caveat):
|
def verify_gen(caveat):
|
|
@ -70,7 +70,7 @@ class PresenceStateTestCase(unittest.TestCase):
|
||||||
return defer.succeed([])
|
return defer.succeed([])
|
||||||
self.datastore.get_presence_list = get_presence_list
|
self.datastore.get_presence_list = get_presence_list
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(myid),
|
"user": UserID.from_string(myid),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
|
@ -78,7 +78,7 @@ class PresenceStateTestCase(unittest.TestCase):
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
room_member_handler = hs.handlers.room_member_handler = Mock(
|
room_member_handler = hs.handlers.room_member_handler = Mock(
|
||||||
spec=[
|
spec=[
|
||||||
|
@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.datastore.has_presence_state = has_presence_state
|
self.datastore.has_presence_state = has_presence_state
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(myid),
|
"user": UserID.from_string(myid),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
|
@ -173,7 +173,7 @@ class PresenceListTestCase(unittest.TestCase):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
presence.register_servlets(hs, self.mock_resource)
|
presence.register_servlets(hs, self.mock_resource)
|
||||||
|
|
||||||
|
|
|
@ -54,14 +54,14 @@ class RoomPermissionsTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"device_id": None,
|
"device_id": None,
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -441,14 +441,14 @@ class RoomsMemberListTestCase(RestTestCase):
|
||||||
|
|
||||||
self.auth_user_id = self.user_id
|
self.auth_user_id = self.user_id
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"device_id": None,
|
"device_id": None,
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -521,14 +521,14 @@ class RoomsCreateTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"device_id": None,
|
"device_id": None,
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
|
@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase):
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -721,14 +721,14 @@ class RoomMemberStateTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"device_id": None,
|
"device_id": None,
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -848,14 +848,14 @@ class RoomMessagesTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"device_id": None,
|
"device_id": None,
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
@ -945,14 +945,14 @@ class RoomInitialSyncTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"device_id": None,
|
"device_id": None,
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
|
@ -61,7 +61,7 @@ class RoomTypingTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
|
@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase):
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_token = _get_user_by_token
|
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
|
@ -37,7 +37,7 @@ class RestTestCase(unittest.TestCase):
|
||||||
self.mock_resource = None
|
self.mock_resource = None
|
||||||
self.auth_user_id = None
|
self.auth_user_id = None
|
||||||
|
|
||||||
def mock_get_user_by_token(self, token=None):
|
def mock_get_user_by_access_token(self, token=None):
|
||||||
return self.auth_user_id
|
return self.auth_user_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -43,14 +43,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
|
||||||
resource_for_federation=self.mock_resource,
|
resource_for_federation=self.mock_resource,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_token(token=None):
|
def _get_user_by_access_token(token=None):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.USER_ID),
|
"user": UserID.from_string(self.USER_ID),
|
||||||
"admin": False,
|
"admin": False,
|
||||||
"device_id": None,
|
"device_id": None,
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
}
|
}
|
||||||
hs.get_auth().get_user_by_token = _get_user_by_token
|
hs.get_auth().get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
for r in self.TO_REGISTER:
|
for r in self.TO_REGISTER:
|
||||||
r.register_servlets(hs, self.mock_resource)
|
r.register_servlets(hs, self.mock_resource)
|
||||||
|
|
|
@ -17,7 +17,9 @@
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import StoreError
|
||||||
from synapse.storage.registration import RegistrationStore
|
from synapse.storage.registration import RegistrationStore
|
||||||
|
from synapse.util import stringutils
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield setup_test_homeserver()
|
hs = yield setup_test_homeserver()
|
||||||
|
self.db_pool = hs.get_db_pool()
|
||||||
|
|
||||||
self.store = RegistrationStore(hs)
|
self.store = RegistrationStore(hs)
|
||||||
|
|
||||||
|
@ -46,7 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
(yield self.store.get_user_by_id(self.user_id))
|
(yield self.store.get_user_by_id(self.user_id))
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self.store.get_user_by_token(self.tokens[0])
|
result = yield self.store.get_user_by_access_token(self.tokens[0])
|
||||||
|
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
{
|
{
|
||||||
|
@ -64,7 +67,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
|
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
|
||||||
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
|
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
|
||||||
|
|
||||||
result = yield self.store.get_user_by_token(self.tokens[1])
|
result = yield self.store.get_user_by_access_token(self.tokens[1])
|
||||||
|
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
{
|
{
|
||||||
|
@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertTrue("token_id" in result)
|
self.assertTrue("token_id" in result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_exchange_refresh_token_valid(self):
|
||||||
|
uid = stringutils.random_string(32)
|
||||||
|
generator = TokenGenerator()
|
||||||
|
last_token = generator.generate(uid)
|
||||||
|
|
||||||
|
self.db_pool.runQuery(
|
||||||
|
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
|
||||||
|
(uid, last_token,))
|
||||||
|
|
||||||
|
(found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
|
||||||
|
last_token, generator.generate)
|
||||||
|
self.assertEqual(uid, found_user_id)
|
||||||
|
|
||||||
|
rows = yield self.db_pool.runQuery(
|
||||||
|
"SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
|
||||||
|
self.assertEqual([(refresh_token,)], rows)
|
||||||
|
# We issued token 1, then exchanged it for token 2
|
||||||
|
expected_refresh_token = u"%s-%d" % (uid, 2,)
|
||||||
|
self.assertEqual(expected_refresh_token, refresh_token)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_exchange_refresh_token_none(self):
|
||||||
|
uid = stringutils.random_string(32)
|
||||||
|
generator = TokenGenerator()
|
||||||
|
last_token = generator.generate(uid)
|
||||||
|
|
||||||
|
with self.assertRaises(StoreError):
|
||||||
|
yield self.store.exchange_refresh_token(last_token, generator.generate)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_exchange_refresh_token_invalid(self):
|
||||||
|
uid = stringutils.random_string(32)
|
||||||
|
generator = TokenGenerator()
|
||||||
|
last_token = generator.generate(uid)
|
||||||
|
wrong_token = "%s-wrong" % (last_token,)
|
||||||
|
|
||||||
|
self.db_pool.runQuery(
|
||||||
|
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
|
||||||
|
(uid, wrong_token,))
|
||||||
|
|
||||||
|
with self.assertRaises(StoreError):
|
||||||
|
yield self.store.exchange_refresh_token(last_token, generator.generate)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenGenerator:
|
||||||
|
def __init__(self):
|
||||||
|
self._last_issued_token = 0
|
||||||
|
|
||||||
|
def generate(self, user_id):
|
||||||
|
self._last_issued_token += 1
|
||||||
|
return u"%s-%d" % (user_id, self._last_issued_token,)
|
||||||
|
|
|
@ -277,7 +277,7 @@ class MemoryDataStore(object):
|
||||||
raise StoreError(400, "User in use.")
|
raise StoreError(400, "User in use.")
|
||||||
self.tokens_to_users[token] = user_id
|
self.tokens_to_users[token] = user_id
|
||||||
|
|
||||||
def get_user_by_token(self, token):
|
def get_user_by_access_token(self, token):
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
"name": self.tokens_to_users[token],
|
"name": self.tokens_to_users[token],
|
||||||
|
|
Loading…
Add table
Reference in a new issue