0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-11-16 06:41:06 +01:00

Password changes (#1397)

* User API support for password changes

* Password changes in client API

* Update sytest-whitelist

* Remove debug logging

* Default logout_devices to true

* Fix deleting devices by local part
This commit is contained in:
Neil Alexander 2020-09-04 15:16:13 +01:00 committed by GitHub
parent ca8dcf46b7
commit 5076925c18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 268 additions and 26 deletions

View file

@ -5,6 +5,7 @@ type LoginType string
// The relevant login types implemented in Dendrite // The relevant login types implemented in Dendrite
const ( const (
LoginTypePassword = "m.login.password"
LoginTypeDummy = "m.login.dummy" LoginTypeDummy = "m.login.dummy"
LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeRecaptcha = "m.login.recaptcha"

View file

@ -0,0 +1,127 @@
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type newPasswordRequest struct {
NewPassword string `json:"new_password"`
LogoutDevices bool `json:"logout_devices"`
Auth newPasswordAuth `json:"auth"`
}
type newPasswordAuth struct {
Type string `json:"type"`
Session string `json:"session"`
auth.PasswordRequest
}
func Password(
req *http.Request,
userAPI userapi.UserInternalAPI,
accountDB accounts.Database,
device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
// Check that the existing password is right.
var r newPasswordRequest
r.LogoutDevices = true
// Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
// Retrieve or generate the sessionID
sessionID := r.Auth.Session
if sessionID == "" {
// Generate a new, random session ID
sessionID = util.RandomString(sessionIDLength)
}
// Require password auth to change the password.
if r.Auth.Type != authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
},
nil,
),
}
}
// Check if the existing password is correct.
typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg,
}
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr
}
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil {
return *resErr
}
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
// Ask the user API to perform the password change.
passwordReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
Password: r.NewPassword,
}
passwordRes := &userapi.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed")
return jsonerror.InternalServerError()
}
if !passwordRes.PasswordUpdated {
util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't")
return jsonerror.InternalServerError()
}
// If the request asks us to log out all other devices then
// ask the user API to do that.
if r.LogoutDevices {
logoutReq := &userapi.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: nil,
ExceptDeviceID: device.ID,
}
logoutRes := &userapi.PerformDeviceDeletionResponse{}
if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
}
// Return a success code.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -417,6 +417,15 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
return Password(req, userAPI, accountDB, device, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub endpoints required by Riot // Stub endpoints required by Riot
r0mux.Handle("/login", r0mux.Handle("/login",

View file

@ -460,3 +460,8 @@ If user leaves room, remote user changes device and rejoins we see update in /sy
Can search public room list Can search public room list
Can get remote public room list Can get remote public room list
Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list
After changing password, can't log in with old password
After changing password, can log in with new password
After changing password, existing session still works
After changing password, different sessions can optionally be kept
After changing password, a different session no longer works by default

View file

@ -26,6 +26,7 @@ import (
type UserInternalAPI interface { type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct {
UserID string UserID string
// The devices to delete. An empty slice means delete all devices. // The devices to delete. An empty slice means delete all devices.
DeviceIDs []string DeviceIDs []string
// The requesting device ID to exclude from deletion. This is needed
// so that a password change doesn't cause that client to be logged
// out. Only specify when DeviceIDs is empty.
ExceptDeviceID string
} }
type PerformDeviceDeletionResponse struct { type PerformDeviceDeletionResponse struct {
@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct {
Account *Account Account *Account
} }
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
type PerformPasswordUpdateResponse struct {
PasswordUpdated bool
Account *Account
}
// PerformDeviceCreationRequest is the request for PerformDeviceCreation // PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct { type PerformDeviceCreationRequest struct {
Localpart string Localpart string

View file

@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc res.Account = acc
return nil return nil
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
return nil
}
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{ util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart, "localpart": req.Localpart,
@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
var devices []api.Device var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices { for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID) deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
} }

View file

@ -30,6 +30,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation" PerformAccountCreationPath = "/userapi/performAccountCreation"
PerformPasswordUpdatePath = "/userapi/performPasswordUpdate"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformPasswordUpdate(
ctx context.Context,
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPasswordUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformDeviceCreation( func (h *httpUserInternalAPI) PerformDeviceCreation(
ctx context.Context, ctx context.Context,
request *api.PerformDeviceCreationRequest, request *api.PerformDeviceCreationRequest,

View file

@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPasswordUpdateRequest{}
response := api.PerformPasswordUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath, internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{} request := api.PerformDeviceCreationRequest{}

View file

@ -28,6 +28,7 @@ type Database interface {
internal.PartitionStorer internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile

View file

@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')" "SELECT nextval('numeric_username_seq')"
// TODO: Update password
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return return
} }
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return return
} }
@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {

View file

@ -112,6 +112,17 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateGuestAccount makes a new guest account and creates an empty profile // CreateGuestAccount makes a new guest account and creates an empty profile
// for this account. // for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts" "SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return return
} }
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return return
} }
@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {

View file

@ -126,6 +126,18 @@ func (d *Database) SetDisplayName(
}) })
} }
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
err = d.accounts.updatePassword(ctx, localpart, hash)
return err
}
// CreateGuestAccount makes a new guest account and creates an empty profile // CreateGuestAccount makes a new guest account and creates an empty profile
// for this account. // for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted. // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
} }

View file

@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1" "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the // deleteDevicesByLocalpart removes all devices for the
// given user localpart. // given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart) _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err return err
} }
@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart. // database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil { if err != nil {
return err return err
} }
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err return err
} }
return nil return nil

View file

@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1" "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart) _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err return err
} }
@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart. // database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil { if err != nil {
return err return err
} }
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err return err
} }
return nil return nil