diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go index 65112cd84..62932b65f 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go @@ -57,13 +57,18 @@ const selectDeviceByTokenSQL = "" + const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" +const deleteDevicesByLocalpartSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1" + // TODO: List devices? type devicesStatements struct { - insertDeviceStmt *sql.Stmt - selectDeviceByTokenStmt *sql.Stmt - deleteDeviceStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + insertDeviceStmt *sql.Stmt + selectDeviceByTokenStmt *sql.Stmt + deleteDeviceStmt *sql.Stmt + deleteDevicesByLocalpartStmt *sql.Stmt + + serverName gomatrixserverlib.ServerName } func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { @@ -80,6 +85,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { return } + if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { + return + } s.serverName = server return } @@ -110,6 +118,14 @@ func (s *devicesStatements) deleteDevice( return err } +func (s *devicesStatements) deleteDevicesByLocalpart( + ctx context.Context, txn *sql.Tx, localpart string, +) error { + stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err +} + func (s *devicesStatements) selectDeviceByToken( ctx context.Context, accessToken string, ) (*authtypes.Device, error) { diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go index ea7d87383..c100e8f58 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go @@ -109,3 +109,17 @@ func (d *Database) RemoveDevice( return nil }) } + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/logout.go b/src/github.com/matrix-org/dendrite/clientapi/routing/logout.go index ff214fe57..d03e7957f 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/logout.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/logout.go @@ -50,3 +50,22 @@ func Logout( JSON: struct{}{}, } } + +// LogoutAll handles POST /logout/all +func LogoutAll( + req *http.Request, deviceDB *devices.Database, device *authtypes.Device, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + + if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: 200, + JSON: struct{}{}, + } +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index 0b9e4172a..04c183f32 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -160,6 +160,12 @@ func Setup( }), ).Methods("POST", "OPTIONS") + r0mux.Handle("/logout/all", + common.MakeAuthAPI("logout", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + return LogoutAll(req, deviceDB, device) + }), + ).Methods("POST", "OPTIONS") + // Stub endpoints required by Riot r0mux.Handle("/login",