mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-19 06:22:04 +01:00
Implement updating devices
You can update the displayname of devices now.
This commit is contained in:
parent
436bffd15f
commit
012b4c1913
5 changed files with 120 additions and 9 deletions
|
@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler):
|
|||
yield self.store.user_delete_access_tokens(user_id,
|
||||
device_id=device_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_device(self, user_id, device_id, content):
|
||||
""" Update the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
content (dict): body of update request
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.update_device(
|
||||
user_id,
|
||||
device_id,
|
||||
new_display_name=content.get("display_name")
|
||||
)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
raise errors.NotFoundError()
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def _update_device_from_client_ips(device, client_ips):
|
||||
|
|
|
@ -13,19 +13,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http import servlet
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DevicesRestServlet(RestServlet):
|
||||
class DevicesRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
|
|||
defer.returnValue((200, {"devices": devices}))
|
||||
|
||||
|
||||
class DeviceRestServlet(RestServlet):
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False)
|
||||
|
||||
|
@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet):
|
|||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
body
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore):
|
|||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to retrieve
|
||||
device_id (str): The ID of the device to delete
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
|
@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore):
|
|||
desc="delete_device",
|
||||
)
|
||||
|
||||
def update_device(self, user_id, device_id, new_display_name=None):
|
||||
"""Update a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to update
|
||||
new_display_name (str|None): new displayname for device; None
|
||||
to leave unchanged
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
updates = {}
|
||||
if new_display_name is not None:
|
||||
updates["display_name"] = new_display_name
|
||||
if not updates:
|
||||
return defer.succeed(None)
|
||||
return self._simple_update_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
updatevalues=updates,
|
||||
desc="update_device",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""Retrieve all of a user's registered devices.
|
||||
|
|
|
@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase):
|
|||
# we'd like to check the access token was invalidated, but that's a
|
||||
# bit of a PITA.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_device(self):
|
||||
yield self._record_users()
|
||||
|
||||
update = {"display_name": "new display"}
|
||||
yield self.handler.update_device(user1, "abc", update)
|
||||
|
||||
res = yield self.handler.get_device(user1, "abc")
|
||||
self.assertEqual(res["display_name"], "new display")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_unknown_device(self):
|
||||
update = {"display_name": "new_display"}
|
||||
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||
yield self.handler.update_device("user_id", "unknown_device_id",
|
||||
update)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _record_users(self):
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
||||
|
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
"device_id": "device2",
|
||||
"display_name": "display_name 2",
|
||||
}, res["device2"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_device(self):
|
||||
yield self.store.store_device(
|
||||
"user_id", "device_id", "display_name 1"
|
||||
)
|
||||
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do a no-op first
|
||||
yield self.store.update_device(
|
||||
"user_id", "device_id",
|
||||
)
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do the update
|
||||
yield self.store.update_device(
|
||||
"user_id", "device_id",
|
||||
new_display_name="display_name 2",
|
||||
)
|
||||
|
||||
# check it worked
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 2", res["display_name"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_unknown_device(self):
|
||||
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||
yield self.store.update_device(
|
||||
"user_id", "unknown_device_id",
|
||||
new_display_name="display_name 2",
|
||||
)
|
||||
self.assertEqual(404, cm.exception.code)
|
||||
|
|
Loading…
Add table
Reference in a new issue