mirror of
https://github.com/matrix-org/dendrite
synced 2024-12-14 00:53:51 +01:00
Add device display names (#319)
This commit is contained in:
parent
8720570bb0
commit
bad701c703
7 changed files with 115 additions and 19 deletions
|
@ -40,7 +40,9 @@ CREATE TABLE IF NOT EXISTS device_devices (
|
|||
-- migration to different domain names easier.
|
||||
localpart TEXT NOT NULL,
|
||||
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The display name, human friendlier than device_id and updatable
|
||||
display_name TEXT
|
||||
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
|
||||
);
|
||||
|
||||
|
@ -49,16 +51,19 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
|
|||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts) VALUES ($1, $2, $3, $4)"
|
||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT created_ts FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id FROM device_devices WHERE localpart = $1"
|
||||
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||
|
@ -66,13 +71,12 @@ const deleteDeviceSQL = "" +
|
|||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1"
|
||||
|
||||
// TODO: List devices?
|
||||
|
||||
type devicesStatements struct {
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
selectDeviceByIDStmt *sql.Stmt
|
||||
selectDevicesByLocalpartStmt *sql.Stmt
|
||||
updateDeviceNameStmt *sql.Stmt
|
||||
deleteDeviceStmt *sql.Stmt
|
||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
|
@ -95,6 +99,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -110,10 +117,11 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
// Returns the device on success.
|
||||
func (s *devicesStatements) insertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string,
|
||||
) (*authtypes.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS); err != nil {
|
||||
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authtypes.Device{
|
||||
|
@ -139,6 +147,14 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*authtypes.Device, error) {
|
||||
|
|
|
@ -75,6 +75,7 @@ func (d *Database) GetDevicesByLocalpart(
|
|||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string,
|
||||
) (dev *authtypes.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
|
@ -84,7 +85,7 @@ func (d *Database) CreateDevice(
|
|||
return err
|
||||
}
|
||||
|
||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken)
|
||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
|
@ -99,7 +100,7 @@ func (d *Database) CreateDevice(
|
|||
|
||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken)
|
||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
|
@ -110,6 +111,16 @@ func (d *Database) CreateDevice(
|
|||
return
|
||||
}
|
||||
|
||||
// UpdateDevice updates the given device with the display name.
|
||||
// Returns SQL error if there are problems and nil on success.
|
||||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevice revokes a device by deleting the entry in the database
|
||||
// matching with the given device ID and user ID localpart
|
||||
// If the device doesn't exist, it will not return an error
|
||||
|
|
|
@ -16,6 +16,7 @@ package routing
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
|
@ -35,6 +36,10 @@ type devicesJSON struct {
|
|||
Devices []deviceJSON `json:"devices"`
|
||||
}
|
||||
|
||||
type deviceUpdateJSON struct {
|
||||
DisplayName *string `json:"display_name"`
|
||||
}
|
||||
|
||||
// GetDeviceByID handles /device/{deviceID}
|
||||
func GetDeviceByID(
|
||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
||||
|
@ -95,3 +100,56 @@ func GetDevicesByLocalpart(
|
|||
JSON: res,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
||||
func UpdateDeviceByID(
|
||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
||||
deviceID string,
|
||||
) util.JSONResponse {
|
||||
if req.Method != "PUT" {
|
||||
return util.JSONResponse{
|
||||
Code: 405,
|
||||
JSON: jsonerror.NotFound("Bad Method"),
|
||||
}
|
||||
}
|
||||
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
|
||||
if err == sql.ErrNoRows {
|
||||
return util.JSONResponse{
|
||||
Code: 404,
|
||||
JSON: jsonerror.NotFound("Unknown device"),
|
||||
}
|
||||
} else if err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
if dev.UserID != device.UserID {
|
||||
return util.JSONResponse{
|
||||
Code: 403,
|
||||
JSON: jsonerror.Forbidden("device not owned by current user"),
|
||||
}
|
||||
}
|
||||
|
||||
defer req.Body.Close() // nolint: errcheck
|
||||
|
||||
payload := deviceUpdateJSON{}
|
||||
|
||||
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,8 +38,9 @@ type flow struct {
|
|||
}
|
||||
|
||||
type passwordRequest struct {
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
|
@ -119,7 +120,7 @@ func Login(
|
|||
|
||||
// TODO: Use the device ID in the request
|
||||
dev, err := deviceDB.CreateDevice(
|
||||
req.Context(), acc.Localpart, nil, token,
|
||||
req.Context(), acc.Localpart, nil, token, r.InitialDisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
|
|
|
@ -60,6 +60,8 @@ type registerRequest struct {
|
|||
Admin bool `json:"admin"`
|
||||
// user-interactive auth params
|
||||
Auth authDict `json:"auth"`
|
||||
|
||||
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||
}
|
||||
|
||||
type authDict struct {
|
||||
|
@ -210,10 +212,10 @@ func Register(
|
|||
return util.MessageResponse(403, "HMAC incorrect")
|
||||
}
|
||||
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, r.InitialDisplayName)
|
||||
case authtypes.LoginTypeDummy:
|
||||
// there is nothing to do
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, r.InitialDisplayName)
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: 501,
|
||||
|
@ -270,10 +272,10 @@ func LegacyRegister(
|
|||
return util.MessageResponse(403, "HMAC incorrect")
|
||||
}
|
||||
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
|
||||
case authtypes.LoginTypeDummy:
|
||||
// there is nothing to do
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: 501,
|
||||
|
@ -287,6 +289,7 @@ func completeRegistration(
|
|||
accountDB *accounts.Database,
|
||||
deviceDB *devices.Database,
|
||||
username, password string,
|
||||
displayName *string,
|
||||
) util.JSONResponse {
|
||||
if username == "" {
|
||||
return util.JSONResponse{
|
||||
|
@ -318,7 +321,7 @@ func completeRegistration(
|
|||
}
|
||||
|
||||
// // TODO: Use the device ID in the request.
|
||||
dev, err := deviceDB.CreateDevice(ctx, username, nil, token)
|
||||
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: 500,
|
||||
|
|
|
@ -364,6 +364,13 @@ func Setup(
|
|||
}),
|
||||
).Methods("GET")
|
||||
|
||||
r0mux.Handle("/devices/{deviceID}",
|
||||
common.MakeAuthAPI("device_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||
vars := mux.Vars(req)
|
||||
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"])
|
||||
}),
|
||||
).Methods("PUT", "OPTIONS")
|
||||
|
||||
// Stub implementations for sytest
|
||||
r0mux.Handle("/events",
|
||||
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||
|
|
|
@ -87,7 +87,7 @@ func main() {
|
|||
}
|
||||
|
||||
device, err := deviceDB.CreateDevice(
|
||||
context.Background(), *username, nil, *accessToken,
|
||||
context.Background(), *username, nil, *accessToken, nil,
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
|
|
Loading…
Reference in a new issue