0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 17:34:00 +01:00

Merge pull request #949 from matrix-org/rav/update_devices

Implement updates and deletes for devices
This commit is contained in:
David Baker 2016-07-26 10:49:55 +01:00 committed by GitHub
commit d34e9f93b7
13 changed files with 296 additions and 29 deletions

View file

@ -77,6 +77,7 @@ class AuthHandler(BaseHandler):
self.ldap_bind_password = hs.config.ldap_bind_password self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -374,7 +375,8 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password) return self._check_password(user_id, password)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None): def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
""" """
Gets login tuple for the user with the given user ID. Gets login tuple for the user with the given user ID.
@ -383,9 +385,15 @@ class AuthHandler(BaseHandler):
The user is assumed to have been authenticated by some other The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case. machanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
Args: Args:
user_id (str): canonical User ID user_id (str): canonical User ID
device_id (str): the device ID to associate with the access token device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns: Returns:
A tuple of: A tuple of:
The access token for the user's session. The access token for the user's session.
@ -397,6 +405,16 @@ class AuthHandler(BaseHandler):
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
defer.returnValue((access_token, refresh_token)) defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler):
Args: Args:
user_id (str): user_id (str):
device_id (str) device_id (str):
Returns: Returns:
defer.Deferred: dict[str, X]: info on the device defer.Deferred: dict[str, X]: info on the device
@ -117,6 +117,55 @@ class DeviceHandler(BaseHandler):
_update_device_from_client_ips(device, ips) _update_device_from_client_ips(device, ips)
defer.returnValue(device) defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(user_id,
device_id=device_id)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})

View file

@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
def register_paths(self, method, path_patterns, callback): def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback) self._PathEntry(path_pattern, callback)
) )

View file

@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet):
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet):
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet):
) )
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id( yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id registered_user_id, device_id,
login_submission.get("initial_device_display_name")
) )
) )
result = { result = {

View file

@ -13,19 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
import logging import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet): class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
defer.returnValue((200, {"devices": devices})) defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(RestServlet): class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False) releases=[], v2_alpha=False)
@ -70,6 +68,32 @@ class DeviceRestServlet(RestServlet):
) )
defer.returnValue((200, device)) defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server)

View file

@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet):
""" """
device_id = yield self._register_device(user_id, params) device_id = yield self._register_device(user_id, params)
access_token = yield self.auth_handler.issue_access_token( access_token, refresh_token = (
user_id, device_id=device_id yield self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
) )
refresh_token = yield self.auth_handler.issue_refresh_token(
user_id, device_id=device_id
)
defer.returnValue({ defer.returnValue({
"user_id": user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,

View file

@ -76,6 +76,46 @@ class DeviceStore(SQLBaseStore):
desc="get_device", desc="get_device",
) )
def delete_device(self, user_id, device_id):
"""Delete a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to delete
Returns:
defer.Deferred
"""
return self._simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_device",
)
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to update
new_display_name (str|None): new displayname for device; None
to leave unchanged
Raises:
StoreError: if the device is not found
Returns:
defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
return self._simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues=updates,
desc="update_device",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_devices_by_user(self, user_id): def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices. """Retrieve all of a user's registered devices.

View file

@ -18,18 +18,31 @@ import re
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(SQLBaseStore): class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs): def __init__(self, hs):
super(RegistrationStore, self).__init__(hs) super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
self.register_background_index_update(
"refresh_tokens_device_index",
index_name="refresh_tokens_device_id",
table="refresh_tokens",
columns=["user_id", "device_id"],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None): def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user. """Adds an access token for the given user.
@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,)) self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[]): def user_delete_access_tokens(self, user_id, except_token_ids=[],
device_id=None):
def f(txn): def f(txn):
sql = "SELECT token FROM access_tokens WHERE user_id = ?" sql = "SELECT token FROM access_tokens WHERE user_id = ?"
clauses = [user_id] clauses = [user_id]
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
if except_token_ids: if except_token_ids:
sql += " AND id NOT IN (%s)" % ( sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_token_ids]), ",".join(["?" for _ in except_token_ids]),

View file

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('access_tokens_device_index', '{}');

View file

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('refresh_tokens_device_index', '{}');

View file

@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse import types
from twisted.internet import defer from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.device import synapse.handlers.device
import synapse.storage import synapse.storage
from synapse import types
from tests import unittest, utils from tests import unittest, utils
user1 = "@boris:aaa" user1 = "@boris:aaa"
@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs) super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore self.store = None # type: synapse.storage.DataStore
self.handler = None # type: device.DeviceHandler self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock self.clock = None # type: utils.MockClock
@defer.inlineCallbacks @defer.inlineCallbacks
@ -123,6 +126,37 @@ class DeviceTestCase(unittest.TestCase):
"last_seen_ts": 3000000, "last_seen_ts": 3000000,
}, res) }, res)
@defer.inlineCallbacks
def test_delete_device(self):
yield self._record_users()
# delete the device
yield self.handler.delete_device(user1, "abc")
# check the device was deleted
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.get_device(user1, "abc")
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self):
yield self._record_users()
update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update)
res = yield self.handler.get_device(user1, "abc")
self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id",
update)
@defer.inlineCallbacks @defer.inlineCallbacks
def _record_users(self): def _record_users(self):
# check this works for both devices which have a recorded client_ip, # check this works for both devices which have a recorded client_ip,

View file

@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock( self.registration_handler.appservice_register = Mock(
return_value=user_id return_value=user_id
) )
self.auth_handler.issue_access_token = Mock(return_value=token) self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
(code, result) = yield self.servlet.on_POST(self.request) (code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200) self.assertEquals(code, 200)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname "home_server": self.hs.hostname
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey" "password": "monkey"
}, None) }, None)
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.issue_access_token = Mock(return_value=token) self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
self.device_handler.check_device_registered = \ self.device_handler.check_device_registered = \
Mock(return_value=device_id) Mock(return_value=device_id)
@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result) self.assertIn("refresh_token", result)
self.auth_handler.issue_access_token.assert_called_once_with( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id) user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False self.hs.config.enable_registration = False

View file

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import synapse.api.errors
import tests.unittest import tests.unittest
import tests.utils import tests.utils
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"device_id": "device2", "device_id": "device2",
"display_name": "display_name 2", "display_name": "display_name 2",
}, res["device2"]) }, res["device2"])
@defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device(
"user_id", "device_id", "display_name 1"
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield self.store.update_device(
"user_id", "device_id",
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
yield self.store.update_device(
"user_id", "device_id",
new_display_name="display_name 2",
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device(
"user_id", "unknown_device_id",
new_display_name="display_name 2",
)
self.assertEqual(404, cm.exception.code)