0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 06:51:46 +01:00

Add device_id support to /login

Add a 'devices' table to the storage, as well as a 'device_id' column to
refresh_tokens.

Allow the client to pass a device_id, and initial_device_display_name, to
/login. If login is successful, then register the device in the devices table
if it wasn't known already. If no device_id was supplied, make one up.

Associate the device_id with the access token and refresh token, so that we can
get at it again later. Ensure that the device_id is copied from the refresh
token to the access_token when the token is refreshed.
This commit is contained in:
Richard van der Hoff 2016-07-15 13:19:07 +01:00
parent 93efcb8526
commit f863a52cea
12 changed files with 354 additions and 31 deletions

View file

@ -361,7 +361,7 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password) return self._check_password(user_id, password)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id): def get_login_tuple_for_user_id(self, user_id, device_id=None):
""" """
Gets login tuple for the user with the given user ID. Gets login tuple for the user with the given user ID.
@ -372,6 +372,7 @@ class AuthHandler(BaseHandler):
Args: Args:
user_id (str): canonical User ID user_id (str): canonical User ID
device_id (str): the device ID to associate with the access token
Returns: Returns:
A tuple of: A tuple of:
The access token for the user's session. The access token for the user's session.
@ -380,9 +381,9 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id) access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id) refresh_token = yield self.issue_refresh_token(user_id, device_id)
defer.returnValue((access_token, refresh_token)) defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -638,15 +639,17 @@ class AuthHandler(BaseHandler):
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id): def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id) access_token = self.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token) yield self.store.add_access_token_to_user(user_id, access_token,
device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_refresh_token(self, user_id): def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id) refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token) yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
defer.returnValue(refresh_token) defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None, def generate_access_token(self, user_id, extra_caveats=None,

View file

@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import StoreError
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name):
"""
If the given device has not been registered, register it with the
supplied display name.
If no device_id is supplied, we make one up.
Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
Returns:
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string_with_symbols(16)
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
defer.returnValue(device_id)
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a device ID.")

View file

@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.servername = hs.config.server_name self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet):
user_id=user_id, user_id=user_id,
password=login_submission["password"], password=login_submission["password"],
) )
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id) yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token, "refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id,
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = ( user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id) yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token, "refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id,
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet):
auth_handler = self.auth_handler auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id) registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id: if registered_user_id:
device_id = yield self._register_device(
registered_user_id, login_submission
)
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(registered_user_id) yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id
)
) )
result = { result = {
"user_id": registered_user_id, "user_id": registered_user_id,
@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
else: else:
# TODO: we should probably check that the register isn't going
# to fonx/change our user_id before registering the device
device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = ( user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user) yield self.handlers.registration_handler.register(localpart=user)
) )
@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet):
return (user, attributes) return (user, attributes)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) login_submission: dictionary supplied to /login call, from
which we pull device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get(
"initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
class SAML2RestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/saml2", releases=()) PATTERNS = client_path_patterns("/login/saml2", releases=())

View file

@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet):
try: try:
old_refresh_token = body["refresh_token"] old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token( refresh_result = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token) old_refresh_token, auth_handler.generate_refresh_token
new_access_token = yield auth_handler.issue_access_token(user_id) )
(user_id, new_refresh_token, device_id) = refresh_result
new_access_token = yield auth_handler.issue_access_token(
user_id, device_id
)
defer.returnValue((200, { defer.returnValue((200, {
"access_token": new_access_token, "access_token": new_access_token,
"refresh_token": new_refresh_token, "refresh_token": new_refresh_token,

View file

@ -25,6 +25,7 @@ from twisted.enterprise import adbapi
from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.api import ApplicationServiceApi
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.handlers.device import DeviceHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.api.auth import Auth from synapse.api.auth import Auth
@ -92,6 +93,7 @@ class HomeServer(object):
'typing_handler', 'typing_handler',
'room_list_handler', 'room_list_handler',
'auth_handler', 'auth_handler',
'device_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -197,6 +199,9 @@ class HomeServer(object):
def build_auth_handler(self): def build_auth_handler(self):
return AuthHandler(self) return AuthHandler(self)
def build_device_handler(self):
return DeviceHandler(self)
def build_application_service_api(self): def build_application_service_api(self):
return ApplicationServiceApi(self) return ApplicationServiceApi(self)

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.devices import DeviceStore
from .appservice import ( from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore ApplicationServiceStore, ApplicationServiceTransactionStore
) )
@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore, EventPushActionsStore,
OpenIdStore, OpenIdStore,
ClientIpStore, ClientIpStore,
DeviceStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):

View file

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name,
ignore_if_known=True):
"""Ensure the given device is known; add it to the store if not
Args:
user_id (str): id of user associated with the device
device_id (str): id of device
initial_device_display_name (str): initial displayname of the
device
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
Returns:
defer.Deferred
Raises:
StoreError: if ignore_if_known is False and the device was already
known
"""
try:
yield self._simple_insert(
"devices",
values={
"user_id": user_id,
"device_id": device_id,
"display_name": initial_device_display_name
},
desc="store_device",
or_ignore=ignore_if_known,
)
except Exception as e:
logger.error("store_device with device_id=%s failed: %s",
device_id, e)
raise StoreError(500, "Problem storing device.")
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a namedtuple containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)

View file

@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token): def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user. """Adds an access token for the given user.
Args: Args:
user_id (str): The user ID. user_id (str): The user ID.
token (str): The new access token to add. token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
token
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
""" """
@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore):
{ {
"id": next_id, "id": next_id,
"user_id": user_id, "user_id": user_id,
"token": token "token": token,
"device_id": device_id,
}, },
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token): def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user. """Adds a refresh token for the given user.
Args: Args:
user_id (str): The user ID. user_id (str): The user ID.
token (str): The new refresh token to add. token (str): The new refresh token to add.
device_id (str): ID of the device to associate with the access
token
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
""" """
@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore):
{ {
"id": next_id, "id": next_id,
"user_id": user_id, "user_id": user_id,
"token": token "token": token,
"device_id": device_id,
}, },
desc="add_refresh_token_to_user", desc="add_refresh_token_to_user",
) )
@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore):
) )
def exchange_refresh_token(self, refresh_token, token_generator): def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new access token and refresh token. """Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single Doing so invalidates the old refresh token - refresh tokens are single
use. use.
Args: Args:
token (str): The refresh token of a user. refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This user ID, returns a unique refresh token for that user. This
function must never return the same value twice. function must never return the same value twice.
Returns: Returns:
tuple of (user_id, refresh_token) tuple of (user_id, new_refresh_token, device_id)
Raises: Raises:
StoreError if no user was found with that refresh token. StoreError if no user was found with that refresh token.
""" """
@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore):
) )
def _exchange_refresh_token(self, txn, old_token, token_generator): def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,)) txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
if not rows: if not rows:
raise StoreError(403, "Did not recognize refresh token") raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"] user_id = rows[0]["user_id"]
device_id = rows[0]["device_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that # TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id. # macaroon.user_id == user_id.
@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore):
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,)) txn.execute(sql, (new_token, old_token,))
return user_id, new_token return user_id, new_token, device_id
@defer.inlineCallbacks @defer.inlineCallbacks
def is_server_admin(self, user): def is_server_admin(self, user):
@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id" "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users" " FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id" " INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?" " WHERE token = ?"

View file

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE devices (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
display_name TEXT,
CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
);

View file

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE refresh_tokens ADD COLUMN device_id BIGINT;

View file

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.handlers.device import DeviceHandler
from tests import unittest
from tests.utils import setup_test_homeserver
class DeviceHandlers(object):
def __init__(self, hs):
self.device_handler = DeviceHandler(hs)
class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = handlers = DeviceHandlers(self.hs)
self.handler = handlers.device_handler
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
user_id="boris",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
user_id="boris",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
user_id="boris",
device_id="fco",
initial_device_display_name="new display name"
)
self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
user_id="theresa",
device_id=None,
initial_device_display_name="display"
)
dev = yield self.handler.store.get_device("theresa", device_id)
self.assertEqual(dev["display_name"], "display")

View file

@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"BcDeFgHiJkLmNoPqRsTuVwXyZa" "BcDeFgHiJkLmNoPqRsTuVwXyZa"
] ]
self.pwhash = "{xx1}123456789" self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
@defer.inlineCallbacks @defer.inlineCallbacks
def test_register(self): def test_register(self):
@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_tokens(self): def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
self.device_id)
result = yield self.store.get_user_by_access_token(self.tokens[1]) result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"name": self.user_id, "name": self.user_id,
"device_id": self.device_id,
}, },
result result
) )
@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_exchange_refresh_token_valid(self): def test_exchange_refresh_token_valid(self):
uid = stringutils.random_string(32) uid = stringutils.random_string(32)
device_id = stringutils.random_string(16)
generator = TokenGenerator() generator = TokenGenerator()
last_token = generator.generate(uid) last_token = generator.generate(uid)
self.db_pool.runQuery( self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", "INSERT INTO refresh_tokens(user_id, token, device_id) "
(uid, last_token,)) "VALUES(?,?,?)",
(uid, last_token, device_id))
(found_user_id, refresh_token) = yield self.store.exchange_refresh_token( (found_user_id, refresh_token, device_id) = \
last_token, generator.generate) yield self.store.exchange_refresh_token(last_token,
generator.generate)
self.assertEqual(uid, found_user_id) self.assertEqual(uid, found_user_id)
rows = yield self.db_pool.runQuery( rows = yield self.db_pool.runQuery(
"SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?",
self.assertEqual([(refresh_token,)], rows) (uid, ))
self.assertEqual([(refresh_token, device_id)], rows)
# We issued token 1, then exchanged it for token 2 # We issued token 1, then exchanged it for token 2
expected_refresh_token = u"%s-%d" % (uid, 2,) expected_refresh_token = u"%s-%d" % (uid, 2,)
self.assertEqual(expected_refresh_token, refresh_token) self.assertEqual(expected_refresh_token, refresh_token)