mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 01:23:51 +01:00
Merge pull request #949 from matrix-org/rav/update_devices
Implement updates and deletes for devices
This commit is contained in:
commit
d34e9f93b7
13 changed files with 296 additions and 29 deletions
|
@ -77,6 +77,7 @@ class AuthHandler(BaseHandler):
|
|||
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
|
@ -374,7 +375,8 @@ class AuthHandler(BaseHandler):
|
|||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_login_tuple_for_user_id(self, user_id, device_id=None):
|
||||
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
"""
|
||||
Gets login tuple for the user with the given user ID.
|
||||
|
||||
|
@ -383,9 +385,15 @@ class AuthHandler(BaseHandler):
|
|||
The user is assumed to have been authenticated by some other
|
||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||
|
||||
The device will be recorded in the table if it is not there already.
|
||||
|
||||
Args:
|
||||
user_id (str): canonical User ID
|
||||
device_id (str): the device ID to associate with the access token
|
||||
device_id (str|None): the device ID to associate with the tokens.
|
||||
None to leave the tokens unassociated with a device (deprecated:
|
||||
we should always have a device ID)
|
||||
initial_display_name (str): display name to associate with the
|
||||
device if it needs re-registering
|
||||
Returns:
|
||||
A tuple of:
|
||||
The access token for the user's session.
|
||||
|
@ -397,6 +405,16 @@ class AuthHandler(BaseHandler):
|
|||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||
access_token = yield self.issue_access_token(user_id, device_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||
|
||||
# the device *should* have been registered before we got here; however,
|
||||
# it's possible we raced against a DELETE operation. The thing we
|
||||
# really don't want is active access_tokens without a record of the
|
||||
# device, so we double-check it here.
|
||||
if device_id is not None:
|
||||
yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
defer.returnValue((access_token, refresh_token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler):
|
|||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str)
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: dict[str, X]: info on the device
|
||||
|
@ -117,6 +117,55 @@ class DeviceHandler(BaseHandler):
|
|||
_update_device_from_client_ips(device, ips)
|
||||
defer.returnValue(device)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_device(self, user_id, device_id):
|
||||
""" Delete the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.delete_device(user_id, device_id)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
yield self.store.user_delete_access_tokens(user_id,
|
||||
device_id=device_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_device(self, user_id, device_id, content):
|
||||
""" Update the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
content (dict): body of update request
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.update_device(
|
||||
user_id,
|
||||
device_id,
|
||||
new_display_name=content.get("display_name")
|
||||
)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
raise errors.NotFoundError()
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def _update_device_from_client_ips(device, client_ips):
|
||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||
|
|
|
@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
|
||||
def register_paths(self, method, path_patterns, callback):
|
||||
for path_pattern in path_patterns:
|
||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
self._PathEntry(path_pattern, callback)
|
||||
)
|
||||
|
|
|
@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
|
@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
|
@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id, device_id
|
||||
registered_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
|
|
|
@ -13,19 +13,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http import servlet
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DevicesRestServlet(RestServlet):
|
||||
class DevicesRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
|
|||
defer.returnValue((200, {"devices": devices}))
|
||||
|
||||
|
||||
class DeviceRestServlet(RestServlet):
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False)
|
||||
|
||||
|
@ -70,6 +68,32 @@ class DeviceRestServlet(RestServlet):
|
|||
)
|
||||
defer.returnValue((200, device))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, device_id):
|
||||
# XXX: it's not completely obvious we want to expose this endpoint.
|
||||
# It allows the client to delete access tokens, which feels like a
|
||||
# thing which merits extra auth. But if we want to do the interactive-
|
||||
# auth dance, we should really make it possible to delete more than one
|
||||
# device at a time.
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
yield self.device_handler.delete_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
body
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet):
|
|||
"""
|
||||
device_id = yield self._register_device(user_id, params)
|
||||
|
||||
access_token = yield self.auth_handler.issue_access_token(
|
||||
user_id, device_id=device_id
|
||||
access_token, refresh_token = (
|
||||
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
initial_display_name=params.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
|
||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
||||
user_id, device_id=device_id
|
||||
)
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
|
|
|
@ -76,6 +76,46 @@ class DeviceStore(SQLBaseStore):
|
|||
desc="get_device",
|
||||
)
|
||||
|
||||
def delete_device(self, user_id, device_id):
|
||||
"""Delete a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to delete
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
return self._simple_delete_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
desc="delete_device",
|
||||
)
|
||||
|
||||
def update_device(self, user_id, device_id, new_display_name=None):
|
||||
"""Update a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to update
|
||||
new_display_name (str|None): new displayname for device; None
|
||||
to leave unchanged
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
updates = {}
|
||||
if new_display_name is not None:
|
||||
updates["display_name"] = new_display_name
|
||||
if not updates:
|
||||
return defer.succeed(None)
|
||||
return self._simple_update_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
updatevalues=updates,
|
||||
desc="update_device",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""Retrieve all of a user's registered devices.
|
||||
|
|
|
@ -18,18 +18,31 @@ import re
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError, Codes
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.storage import background_updates
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
|
||||
|
||||
class RegistrationStore(SQLBaseStore):
|
||||
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RegistrationStore, self).__init__(hs)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self.register_background_index_update(
|
||||
"access_tokens_device_index",
|
||||
index_name="access_tokens_device_id",
|
||||
table="access_tokens",
|
||||
columns=["user_id", "device_id"],
|
||||
)
|
||||
|
||||
self.register_background_index_update(
|
||||
"refresh_tokens_device_index",
|
||||
index_name="refresh_tokens_device_id",
|
||||
table="refresh_tokens",
|
||||
columns=["user_id", "device_id"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_access_token_to_user(self, user_id, token, device_id=None):
|
||||
"""Adds an access token for the given user.
|
||||
|
@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore):
|
|||
self.get_user_by_id.invalidate((user_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
|
||||
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
||||
device_id=None):
|
||||
def f(txn):
|
||||
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
|
||||
clauses = [user_id]
|
||||
|
||||
if device_id is not None:
|
||||
sql += " AND device_id = ?"
|
||||
clauses.append(device_id)
|
||||
|
||||
if except_token_ids:
|
||||
sql += " AND id NOT IN (%s)" % (
|
||||
",".join(["?" for _ in except_token_ids]),
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('access_tokens_device_index', '{}');
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('refresh_tokens_device_index', '{}');
|
|
@ -12,11 +12,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse import types
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.handlers.device
|
||||
|
||||
import synapse.storage
|
||||
from synapse import types
|
||||
from tests import unittest, utils
|
||||
|
||||
user1 = "@boris:aaa"
|
||||
|
@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
self.handler = None # type: device.DeviceHandler
|
||||
self.handler = None # type: synapse.handlers.device.DeviceHandler
|
||||
self.clock = None # type: utils.MockClock
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -123,6 +126,37 @@ class DeviceTestCase(unittest.TestCase):
|
|||
"last_seen_ts": 3000000,
|
||||
}, res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_delete_device(self):
|
||||
yield self._record_users()
|
||||
|
||||
# delete the device
|
||||
yield self.handler.delete_device(user1, "abc")
|
||||
|
||||
# check the device was deleted
|
||||
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||
yield self.handler.get_device(user1, "abc")
|
||||
|
||||
# we'd like to check the access token was invalidated, but that's a
|
||||
# bit of a PITA.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_device(self):
|
||||
yield self._record_users()
|
||||
|
||||
update = {"display_name": "new display"}
|
||||
yield self.handler.update_device(user1, "abc", update)
|
||||
|
||||
res = yield self.handler.get_device(user1, "abc")
|
||||
self.assertEqual(res["display_name"], "new display")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_unknown_device(self):
|
||||
update = {"display_name": "new_display"}
|
||||
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||
yield self.handler.update_device("user_id", "unknown_device_id",
|
||||
update)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _record_users(self):
|
||||
# check this works for both devices which have a recorded client_ip,
|
||||
|
|
|
@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.registration_handler.appservice_register = Mock(
|
||||
return_value=user_id
|
||||
)
|
||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||
return_value=(token, "kermits_refresh_token")
|
||||
)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"refresh_token": "kermits_refresh_token",
|
||||
"home_server": self.hs.hostname
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
|
@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
"password": "monkey"
|
||||
}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||
return_value=(token, "kermits_refresh_token")
|
||||
)
|
||||
self.device_handler.check_device_registered = \
|
||||
Mock(return_value=device_id)
|
||||
|
||||
|
@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"refresh_token": "kermits_refresh_token",
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertIn("refresh_token", result)
|
||||
self.auth_handler.issue_access_token.assert_called_once_with(
|
||||
user_id, device_id=device_id)
|
||||
self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id, initial_device_display_name=None)
|
||||
|
||||
def test_POST_disabled_registration(self):
|
||||
self.hs.config.enable_registration = False
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
||||
|
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
"device_id": "device2",
|
||||
"display_name": "display_name 2",
|
||||
}, res["device2"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_device(self):
|
||||
yield self.store.store_device(
|
||||
"user_id", "device_id", "display_name 1"
|
||||
)
|
||||
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do a no-op first
|
||||
yield self.store.update_device(
|
||||
"user_id", "device_id",
|
||||
)
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do the update
|
||||
yield self.store.update_device(
|
||||
"user_id", "device_id",
|
||||
new_display_name="display_name 2",
|
||||
)
|
||||
|
||||
# check it worked
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 2", res["display_name"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_unknown_device(self):
|
||||
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||
yield self.store.update_device(
|
||||
"user_id", "unknown_device_id",
|
||||
new_display_name="display_name 2",
|
||||
)
|
||||
self.assertEqual(404, cm.exception.code)
|
||||
|
|
Loading…
Reference in a new issue