0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-06-14 02:18:34 +02:00

Password changes (#1397)

* User API support for password changes

* Password changes in client API

* Update sytest-whitelist

* Remove debug logging

* Default logout_devices to true

* Fix deleting devices by local part
This commit is contained in:
Neil Alexander 2020-09-04 15:16:13 +01:00 committed by GitHub
parent ca8dcf46b7
commit 5076925c18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 268 additions and 26 deletions

View file

@ -5,6 +5,7 @@ type LoginType string
// The relevant login types implemented in Dendrite
const (
LoginTypePassword = "m.login.password"
LoginTypeDummy = "m.login.dummy"
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha"

View file

@ -0,0 +1,127 @@
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type newPasswordRequest struct {
NewPassword string `json:"new_password"`
LogoutDevices bool `json:"logout_devices"`
Auth newPasswordAuth `json:"auth"`
}
type newPasswordAuth struct {
Type string `json:"type"`
Session string `json:"session"`
auth.PasswordRequest
}
func Password(
req *http.Request,
userAPI userapi.UserInternalAPI,
accountDB accounts.Database,
device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
// Check that the existing password is right.
var r newPasswordRequest
r.LogoutDevices = true
// Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
// Retrieve or generate the sessionID
sessionID := r.Auth.Session
if sessionID == "" {
// Generate a new, random session ID
sessionID = util.RandomString(sessionIDLength)
}
// Require password auth to change the password.
if r.Auth.Type != authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
},
nil,
),
}
}
// Check if the existing password is correct.
typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg,
}
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr
}
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil {
return *resErr
}
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
// Ask the user API to perform the password change.
passwordReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
Password: r.NewPassword,
}
passwordRes := &userapi.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed")
return jsonerror.InternalServerError()
}
if !passwordRes.PasswordUpdated {
util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't")
return jsonerror.InternalServerError()
}
// If the request asks us to log out all other devices then
// ask the user API to do that.
if r.LogoutDevices {
logoutReq := &userapi.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: nil,
ExceptDeviceID: device.ID,
}
logoutRes := &userapi.PerformDeviceDeletionResponse{}
if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
}
// Return a success code.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -417,6 +417,15 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
return Password(req, userAPI, accountDB, device, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub endpoints required by Riot
r0mux.Handle("/login",

View file

@ -460,3 +460,8 @@ If user leaves room, remote user changes device and rejoins we see update in /sy
Can search public room list
Can get remote public room list
Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list
After changing password, can't log in with old password
After changing password, can log in with new password
After changing password, existing session still works
After changing password, different sessions can optionally be kept
After changing password, a different session no longer works by default

View file

@ -26,6 +26,7 @@ import (
type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct {
UserID string
// The devices to delete. An empty slice means delete all devices.
DeviceIDs []string
// The requesting device ID to exclude from deletion. This is needed
// so that a password change doesn't cause that client to be logged
// out. Only specify when DeviceIDs is empty.
ExceptDeviceID string
}
type PerformDeviceDeletionResponse struct {
@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct {
Account *Account
}
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
type PerformPasswordUpdateResponse struct {
PasswordUpdated bool
Account *Account
}
// PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct {
Localpart string

View file

@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc
return nil
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
return nil
}
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local)
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}

View file

@ -30,6 +30,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
PerformPasswordUpdatePath = "/userapi/performPasswordUpdate"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformPasswordUpdate(
ctx context.Context,
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPasswordUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformDeviceCreation(
ctx context.Context,
request *api.PerformDeviceCreationRequest,

View file

@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPasswordUpdateRequest{}
response := api.PerformPasswordUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}

View file

@ -28,6 +28,7 @@ type Database interface {
internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
// CreateAccount makes a new account with the given login name and password, and creates an empty profile

View file

@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {

View file

@ -112,6 +112,17 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {

View file

@ -126,6 +126,18 @@ func (d *Database) SetDisplayName(
})
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
err = d.accounts.updatePassword(ctx, localpart, hash)
return err
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error)
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
}

View file

@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err

View file

@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil

View file

@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err

View file

@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil