diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go new file mode 100644 index 000000000..0aa736245 --- /dev/null +++ b/clientapi/routing/key_backup.go @@ -0,0 +1,173 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type keyBackupVersion struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` +} + +type keyBackupVersionCreateResponse struct { + Version string `json:"version"` +} + +type keyBackupVersionResponse struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` +} + +// Create a new key backup. Request must contain a `keyBackupVersion`. Returns a `keyBackupVersionCreateResponse`. +// Implements POST /_matrix/client/r0/room_keys/version +func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: "", + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// KeyBackupVersion returns the key backup version specified. If `version` is empty, the latest `keyBackupVersionResponse` is returned. +// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} +func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var queryResp userapi.QueryKeyBackupResponse + userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{}, &queryResp) + if queryResp.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) + } + if !queryResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("version not found"), + } + } + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionResponse{ + Algorithm: queryResp.Algorithm, + AuthData: queryResp.AuthData, + Count: queryResp.Count, + ETag: queryResp.ETag, + Version: queryResp.Version, + }, + } +} + +// Modify the auth data of a key backup. Version must not be empty. Request must contain a `keyBackupVersion` +// Implements PUT /_matrix/client/r0/room_keys/version/{version} +func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var kb keyBackupVersion + resErr := httputil.UnmarshalJSONRequest(req, &kb) + if resErr != nil { + return *resErr + } + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + AuthData: kb.AuthData, + Algorithm: kb.Algorithm, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} + +// Delete a version of key backup. Version must not be empty. If the key backup was previously deleted, will return 200 OK. +// Implements DELETE /_matrix/client/r0/room_keys/version/{version} +func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { + var performKeyBackupResp userapi.PerformKeyBackupResponse + userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + UserID: device.UserID, + Version: version, + DeleteBackup: true, + }, &performKeyBackupResp) + if performKeyBackupResp.Error != "" { + if performKeyBackupResp.BadInput { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue(performKeyBackupResp.Error), + } + } + return util.ErrorResponse(fmt.Errorf("PerformKeyBackup: %s", performKeyBackupResp.Error)) + } + if !performKeyBackupResp.Exists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("backup version not found"), + } + } + // Unclear what the 200 body should be + return util.JSONResponse{ + Code: 200, + JSON: keyBackupVersionCreateResponse{ + Version: performKeyBackupResp.Version, + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d768247a4..194ab2997 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -896,6 +896,48 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + // Key Backup Versions + r0mux.Handle("/room_keys/version/{versionID}", + httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + return KeyBackupVersion(req, userAPI, device, version) + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version", + httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return KeyBackupVersion(req, userAPI, device, "") + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/room_keys/version/{versionID}", + httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + return ModifyKeyBackupVersionAuthData(req, userAPI, device, version) + }), + ).Methods(http.MethodPut) + r0mux.Handle("/room_keys/version/{versionID}", + httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + version := req.URL.Query().Get("version") + if version == "" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("version must be specified"), + } + } + return DeleteKeyBackupVersion(req, userAPI, device, version) + }), + ).Methods(http.MethodDelete) + r0mux.Handle("/room_keys/version", + httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return CreateKeyBackupVersion(req, userAPI, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Supplying a device ID is deprecated. r0mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 79aaebc0b..1bbb485cf 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -554,6 +554,10 @@ func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.Quer func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} type testRoomserverAPI struct { // use a trace API as it implements method stubs so we don't need to have them here. diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 96160c10d..2c195d128 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -373,6 +373,10 @@ func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *usera func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { return nil } +func (u *testUserAPI) PerformKeyBackup(ctx context.Context, req *userapi.PerformKeyBackupRequest, res *userapi.PerformKeyBackupResponse) { +} +func (u *testUserAPI) QueryKeyBackup(ctx context.Context, req *userapi.QueryKeyBackupRequest, res *userapi.QueryKeyBackupResponse) { +} func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { dev, ok := u.accessTokens[req.AccessToken] if !ok { diff --git a/sytest-whitelist b/sytest-whitelist index d90ba4fb9..e2bb41da4 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -540,3 +540,4 @@ Key notary server must not overwrite a valid key with a spurious result from the GET /rooms/:room_id/aliases lists aliases Only room members can list aliases of a room Users with sufficient power-level can delete other's aliases +Can create more than 10 backup versions diff --git a/userapi/api/api.go b/userapi/api/api.go index 407350123..ff10bfcd8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -33,6 +33,8 @@ type UserInternalAPI interface { PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error @@ -42,6 +44,37 @@ type UserInternalAPI interface { QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error } +type PerformKeyBackupRequest struct { + UserID string + Version string // optional if modifying a key backup + AuthData json.RawMessage + Algorithm string + DeleteBackup bool // if true will delete the backup based on 'Version'. +} + +type PerformKeyBackupResponse struct { + Error string // set if there was a problem performing the request + BadInput bool // if set, the Error was due to bad input (HTTP 400) + Exists bool // set to true if the Version exists + Version string +} + +type QueryKeyBackupRequest struct { + UserID string + Version string // the version to query, if blank it means the latest +} + +type QueryKeyBackupResponse struct { + Error string + Exists bool + + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` +} + // InputAccountDataRequest is the request for InputAccountData type InputAccountDataRequest struct { UserID string // required: the user to set account data for diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 21933c1c4..9ff69298a 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -442,3 +442,57 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp return nil } + +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // Delete + if req.DeleteBackup { + if req.Version == "" { + res.BadInput = true + res.Error = "must specify a version to delete" + return + } + exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + if err != nil { + res.Error = fmt.Sprintf("failed to delete backup: %s", err) + } + res.Exists = exists + res.Version = req.Version + return + } + // Create + if req.Version == "" { + version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to create backup: %s", err) + } + res.Exists = err == nil + res.Version = version + return + } + // Update + err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to update backup: %s", err) + } + res.Version = req.Version +} + +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + version, algorithm, authData, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + res.Version = version + if err != nil { + if err == sql.ErrNoRows { + res.Exists = false + return + } + res.Error = fmt.Sprintf("failed to query key backup: %s", err) + return + } + res.Algorithm = algorithm + res.AuthData = authData + res.Exists = !deleted + + // TODO: + res.Count = 0 + res.ETag = "" +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 1cb5ef0a8..a89d1a266 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -36,7 +36,9 @@ const ( PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" + PerformKeyBackupPath = "/userapi/performKeyBackup" + QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryProfilePath = "/userapi/queryProfile" QueryAccessTokenPath = "/userapi/queryAccessToken" QueryDevicesPath = "/userapi/queryDevices" @@ -225,3 +227,24 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que apiURL := h.apiURL + QueryOpenIDTokenPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + PerformKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} +func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") + defer span.Finish() + + apiURL := h.apiURL + QueryKeyBackupPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = err.Error() + } +} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 5aa61b909..88fdab481 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -54,6 +54,12 @@ type Database interface { DeactivateAccount(ctx context.Context, localpart string) (err error) CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) + + // Key backups + CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) + DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) + GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go new file mode 100644 index 000000000..1b693e565 --- /dev/null +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -0,0 +1,144 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" +) + +const keyBackupVersionTableSchema = ` +CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq; + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'), + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { + return + } + if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { + return + } + if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { + return + } + if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { + return + } + if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { + return + } + return +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + var versionInt int64 + if version == "" { + err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + } else { + versionInt, err = strconv.ParseInt(version, 10, 64) + } + if err != nil { + return + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index c5e74ed15..719e9878c 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -45,6 +45,7 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackups keyBackupVersionStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -93,6 +94,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -368,3 +372,42 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go new file mode 100644 index 000000000..3e85705cf --- /dev/null +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -0,0 +1,142 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strconv" +) + +const keyBackupVersionTableSchema = ` +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( + user_id TEXT NOT NULL, + -- this means no 2 users will ever have the same version of e2e session backups which strictly + -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + version INTEGER PRIMARY KEY AUTOINCREMENT, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted INTEGER DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +` + +const insertKeyBackupSQL = "" + + "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version" + +const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well? + "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +const deleteKeyBackupSQL = "" + + "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +const selectKeyBackupSQL = "" + + "SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +const selectLatestVersionSQL = "" + + "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" + +type keyBackupVersionStatements struct { + insertKeyBackupStmt *sql.Stmt + updateKeyBackupAuthDataStmt *sql.Stmt + deleteKeyBackupStmt *sql.Stmt + selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt *sql.Stmt +} + +func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(keyBackupVersionTableSchema) + if err != nil { + return + } + if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { + return + } + if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { + return + } + if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { + return + } + if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { + return + } + if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { + return + } + return +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + var versionInt int64 + err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt) + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, +) error { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (bool, error) { + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + if err != nil { + return false, err + } + ra, err := result.RowsAffected() + if err != nil { + return false, err + } + return ra == 1, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, txn *sql.Tx, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + var versionInt int64 + if version == "" { + err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt) + } else { + versionInt, err = strconv.ParseInt(version, 10, 64) + } + if err != nil { + return + } + versionResult = strconv.FormatInt(versionInt, 10) + var deletedInt int + var authDataStr string + err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt) + deleted = deletedInt == 1 + authData = json.RawMessage(authDataStr) + return +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index c0f7118cb..99e016ffd 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -43,6 +43,7 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackups keyBackupVersionStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -97,6 +98,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -406,3 +410,42 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) + return err + }) + return +}