diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 8ba890e75..67113367b 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,19 +57,20 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(accountDataSchema) +func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) @@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData( return } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, rows.Err() } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 9e3e456a7..92311d56d 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -78,14 +79,15 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( +func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount( }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString @@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 64cc0b71a..7bc5dc69b 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -111,53 +112,32 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if err = s.execSchema(db); err != nil { - return + if err != nil { + return nil, err } - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.deleteDevicesStmt, deleteDevicesSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -179,7 +159,7 @@ func (s *devicesStatements) insertDevice( } // deleteDevice removes a single device by id and user localpart. -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -189,7 +169,7 @@ func (s *devicesStatements) deleteDevice( // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) @@ -199,7 +179,7 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -207,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -215,7 +195,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -231,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -248,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) if err != nil { return nil, err @@ -271,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -313,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index c1402d4d2..ac0e80617 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go index d73447b49..e78e4cd51 100644 --- a/userapi/storage/postgres/key_backup_version_table.go +++ b/userapi/storage/postgres/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -69,12 +70,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go index 508a68989..4de96f839 100644 --- a/userapi/storage/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -21,18 +21,11 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/util" ) -type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectByTokenStmt *sql.Stmt -} - -// execSchema ensures tables and indices exist. -func (s *loginTokenStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(` +const loginTokenSchema = ` CREATE TABLE IF NOT EXISTS login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, @@ -45,24 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens ( -- This index allows efficient garbage collection of expired tokens. CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); -`) - return err +` + +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt } -// prepare runs statement preparation. -func (s *loginTokenStatements) prepare(db *sql.DB) error { - if err := s.execSchema(db); err != nil { - return err +func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err } - return sqlutil.StatementList{ - {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, - {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, }.Prepare(db) } // insert adds an already generated token to the database. -func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { stmt := sqlutil.TxStmt(txn, s.insertStmt) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) return err @@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // The login_tokens_expiration_idx index should make that efficient. -func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) if err != nil { @@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t } // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. -func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 190d141b7..29c3ddcb4 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return +func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 9313864be..32a4b5506 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -59,12 +60,13 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(profilesSchema) +func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{} + _, err := db.Exec(profilesSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( - ctx context.Context, localpart string, avatarURL string, +func (s *profilesStatements) SetAvatarURL( + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) return } -func (s *profilesStatements) setDisplayName( - ctx context.Context, localpart string, displayName string, +func (s *profilesStatements) SetDisplayName( + ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 734192798..ac5c59b81 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -15,76 +15,33 @@ package postgres import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "encoding/json" - "errors" "fmt" - "strconv" "time" "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/shared" // Import the postgres database driver. _ "github.com/lib/pq" ) -// Database represents an account database -type Database struct { - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration - keyBackups keyBackupStatements - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 -} - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - loginTokenByteLength = 32 -) - // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewDummyWriter(), - loginTokenLifetime: loginTokenLifetime, - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet return nil, err } - m := sqlutil.NewMigrations() deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) @@ -92,638 +49,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return nil, err } - if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - if err = d.devices.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.loginTokens.prepare(db); err != nil { - return nil, err - } - - return d, nil -} - -// GetAccountByPassword returns the account associated with the given localpart and password. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, -) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) + accountDataTable, err := NewPostgresAccountDataTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { - return nil, err - } - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// GetProfileByLocalpart returns the profile associated with the given localpart. -// Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, -) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) -} - -// SetAvatarURL updates the avatar URL of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, -) error { - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) -} - -// SetDisplayName updates the display name of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, -) error { - return d.profiles.setDisplayName(ctx, localpart, displayName) -} - -// SetPassword sets the account password to the given hash. -func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, -) error { - hash, err := d.hashPassword(plaintextPassword) + accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { - return err + return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) } - return d.accounts.updatePassword(ctx, localpart, hash) -} - -// CreateAccount makes a new account with the given login name and password, and creates an empty profile -// for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, sqlutil.ErrUserExists. -func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - // For guest accounts, we create a new numeric local part - if accountType == api.AccountTypeGuest { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart = strconv.FormatInt(numLocalpart, 10) - plaintextPassword = "" - appserviceID = "" - } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) - return err - }) - return -} - -func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (*api.Account, error) { - var account *api.Account - var err error - // Generate a password hash if this is not a password-less user - hash := "" - if plaintextPassword != "" { - hash, err = d.hashPassword(plaintextPassword) - if err != nil { - return nil, err - } - } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { - if sqlutil.IsUniqueConstraintViolationErr(err) { - return nil, sqlutil.ErrUserExists - } - return nil, err - } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { - return nil, err - } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { - return nil, err - } - return account, nil -} - -// SaveAccountData saves new account data for a given user and a given room. -// If the account data is not specific to a room, the room ID should be an empty string -// If an account data already exists for a given set (user, room, data type), it will -// update the corresponding row with the new content -// Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) - }) -} - -// GetAccountData returns account data related to a given localpart -// If no account data could be found, returns an empty arrays -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, -) { - return d.accountDatas.selectAccountData(ctx, localpart) -} - -// GetAccountDataByType returns account data matching a given -// localpart, room ID and type. -// If no account data could be found, returns nil -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, -) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( - ctx, localpart, roomID, dataType, - ) -} - -// GetNewNumericLocalpart generates and returns a new unused numeric localpart -func (d *Database) GetNewNumericLocalpart( - ctx context.Context, -) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) -} - -func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) - return string(hashBytes), err -} - -// Err3PIDInUse is the error returned when trying to save an association involving -// a third-party identifier which is already associated to a local user. -var Err3PIDInUse = errors.New("this third-party identifier is already in use") - -// SaveThreePIDAssociation saves the association between a third party identifier -// and a local Matrix user (identified by the user's ID's local part). -// If the third-party identifier is already part of an association, returns Err3PIDInUse. -// Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, -) (err error) { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } - - if len(user) > 0 { - return Err3PIDInUse - } - - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) -} - -// RemoveThreePIDAssociation removes the association involving a given third-party -// identifier. -// If no association exists involving this third-party identifier, returns nothing. -// If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation( - ctx context.Context, threepid string, medium string, -) (err error) { - return d.threepids.deleteThreePID(ctx, threepid, medium) -} - -// GetLocalpartForThreePID looks up the localpart associated with a given third-party -// identifier. -// If no association involves the given third-party idenfitier, returns an empty -// string. -// Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID( - ctx context.Context, threepid string, medium string, -) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) -} - -// GetThreePIDsForLocalpart looks up the third-party identifiers associated with -// a given local user. -// If no association is known for this user, returns an empty slice. -// Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, -) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) -} - -// CheckAccountAvailability checks if the username/localpart is already present -// in the database. -// If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err -} - -// GetAccountByLocalpart returns the account associated with the given localpart. -// This function assumes the request is authenticated or the account data is used only internally. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// SearchProfiles returns all profiles where the provided localpart or display name -// match any part of the profiles in the database. -func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, -) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) -} - -// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.accounts.deactivateAccount(ctx, localpart) -} - -// CreateOpenIDToken persists a new token that was issued through OpenID Connect -func (d *Database) CreateOpenIDToken( - ctx context.Context, - token, localpart string, -) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) - }) - return expiresAtMS, err -} - -// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token -func (d *Database) GetOpenIDTokenAttributes( - ctx context.Context, - token string, -) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) -} - -func (d *Database) CreateKeyBackup( - ctx context.Context, userID, algorithm string, authData json.RawMessage, -) (version string, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") - return err - }) - return -} - -func (d *Database) UpdateKeyBackupAuthData( - ctx context.Context, userID, version string, authData json.RawMessage, -) (err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) - }) - return -} - -func (d *Database) DeleteKeyBackup( - ctx context.Context, userID, version string, -) (exists bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetKeyBackup( - ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetBackupKeys( - ctx context.Context, version, userID, filterRoomID, filterSessionID string, -) (result map[string]map[string]api.KeyBackupSession, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) - return err - } - if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) - return err - } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) CountBackupKeys( - ctx context.Context, version, userID string, -) (count int64, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - return nil - }) - return -} - -// nolint:nakedret -func (d *Database) UpsertBackupKeys( - ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, -) (count int64, etag string, err error) { - // wrap the following logic in a txn to ensure we atomically upload keys - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - if err != nil { - return err - } - if deleted { - return fmt.Errorf("backup was deleted") - } - // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) - if err != nil { - return err - } - - changed := false - // loop over all the new keys (which should be smaller than the set of backed up keys) - for _, newKey := range uploads { - // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. - existingRoom := existingKeys[newKey.RoomID] - if existingRoom != nil { - existingSession, ok := existingRoom[newKey.SessionID] - if ok { - if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) - } - } - // if we shouldn't replace the key we do nothing with it - continue - } - } - // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) - } - } - - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - if changed { - // update the etag - var newETag string - if oldETag == "" { - newETag = "1" - } else { - oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) - if err != nil { - return fmt.Errorf("failed to parse old etag: %s", err) - } - newETag = strconv.FormatInt(oldETagInt+1, 10) - } - etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) - } else { - etag = oldETag - } - return nil - }) - return -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) + devicesTable, err := NewPostgresDevicesTable(db, serverName) if err != nil { - return "", err + return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() + keyBackupTable, err := NewPostgresKeyBackupTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err) } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) + keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err) } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) + loginTokenTable, err := NewPostgresLoginTokenTable(db) if err != nil { - return "", err + return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err) } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) + openIDTable, err := NewPostgresOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err) + } + profilesTable, err := NewPostgresProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err) + } + threePIDTable, err := NewPostgresThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewDummyWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil } diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 9280fc87c..63c08d61f 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -58,12 +59,13 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(threepidSchema) +func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{} + _, err := db.Exec(threepidSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID( return } -func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) +func (s *threepidStatements) DeleteThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) return } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go new file mode 100644 index 000000000..5f1f95005 --- /dev/null +++ b/userapi/storage/shared/storage.go @@ -0,0 +1,672 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +// Database represents an account database +type Database struct { + DB *sql.DB + Writer sqlutil.Writer + Accounts tables.AccountsTable + Profiles tables.ProfileTable + AccountDatas tables.AccountDataTable + ThreePIDs tables.ThreePIDTable + OpenIDTokens tables.OpenIDTable + KeyBackups tables.KeyBackupTable + KeyBackupVersions tables.KeyBackupVersionTable + Devices tables.DevicesTable + LoginTokens tables.LoginTokenTable + LoginTokenLifetime time.Duration + ServerName gomatrixserverlib.ServerName + BcryptCost int + OpenIDTokenLifetimeMS int64 +} + +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) + +// GetAccountByPassword returns the account associated with the given localpart and password. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*api.Account, error) { + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) + if err != nil { + return nil, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) +} + +// GetProfileByLocalpart returns the profile associated with the given localpart. +// Returns sql.ErrNoRows if no profile exists which matches the given localpart. +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.Profiles.SelectProfileByLocalpart(ctx, localpart) +} + +// SetAvatarURL updates the avatar URL of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + }) +} + +// SetDisplayName updates the display name of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + }) +} + +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := d.hashPassword(plaintextPassword) + if err != nil { + return err + } + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdatePassword(ctx, localpart, hash) + }) +} + +// CreateAccount makes a new account with the given login name and password, and creates an empty profile +// for this account. If no password is supplied, the account will be a passwordless account. If the +// account already exists, it will return nil, ErrUserExists. +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, +) (acc *api.Account, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) + return err + }) + return +} + +// WARNING! This function assumes that the relevant mutexes have already +// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, +) (*api.Account, error) { + var err error + var account *api.Account + // Generate a password hash if this is not a password-less user + hash := "" + if plaintextPassword != "" { + hash, err = d.hashPassword(plaintextPassword) + if err != nil { + return nil, err + } + } + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { + return nil, sqlutil.ErrUserExists + } + if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { + return nil, err + } + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`)); err != nil { + return nil, err + } + return account, nil +} + +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, + err error, +) { + return d.AccountDatas.SelectAccountData(ctx, localpart) +} + +// GetAccountDataByType returns account data matching a given +// localpart, room ID and type. +// If no account data could be found, returns nil +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data json.RawMessage, err error) { + return d.AccountDatas.SelectAccountDataByType( + ctx, localpart, roomID, dataType, + ) +} + +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.Accounts.SelectNewNumericLocalpart(ctx, nil) +} + +func (d *Database) hashPassword(plaintext string) (hash string, err error) { + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost) + return string(hashBytes), err +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("this third-party identifier is already in use") + +// SaveThreePIDAssociation saves the association between a third party identifier +// and a local Matrix user (identified by the user's ID's local part). +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + user, err := d.ThreePIDs.SelectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) + if err != nil { + return err + } + + if len(user) > 0 { + return Err3PIDInUse + } + + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) + }) +} + +// RemoveThreePIDAssociation removes the association involving a given third-party +// identifier. +// If no association exists involving this third-party identifier, returns nothing. +// If there was a problem talking to the database, returns an error. +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium) + }) +} + +// GetLocalpartForThreePID looks up the localpart associated with a given third-party +// identifier. +// If no association involves the given third-party idenfitier, returns an empty +// string. +// Returns an error if there was a problem talking to the database. +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) +} + +// GetThreePIDsForLocalpart looks up the third-party identifiers associated with +// a given local user. +// If no association is known for this user, returns an empty slice. +// Returns an error if there was an issue talking to the database. +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) +} + +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} + +// GetAccountByLocalpart returns the account associated with the given localpart. +// This function assumes the request is authenticated or the account data is used only internally. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +) (*api.Account, error) { + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) +} + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit) +} + +// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. +func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.DeactivateAccount(ctx, localpart) + }) +} + +// CreateOpenIDToken persists a new token that was issued for OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, +) (int64, error) { + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) + }) + return expiresAtMS, err +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token) +} + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "") + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { + err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err) + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err) + } + } + + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + + return nil + }) + return +} + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.Devices.SelectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.LoginTokenLifetime), + } + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.DeleteLoginToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.LoginTokens.SelectLoginToken(ctx, token) +} diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index 871f996e0..cfd8568a9 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,27 +57,29 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(accountDataSchema) - if err != nil { - return +func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(accountDataSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) error { _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return err } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, nil } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 5a918e034..e6c37e58e 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -77,15 +78,16 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( +func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -122,28 +124,28 @@ func (s *accountsStatements) insertAccount( }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString @@ -167,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 119ecdf93..423640e90 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" @@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" + type devicesStatements struct { db *sql.DB - writer sqlutil.Writer insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -98,55 +98,33 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.writer = writer - if err = s.execSchema(db); err != nil { - return + if err != nil { + return nil, err } - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return - } - if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDevicesCountStmt, selectDevicesCountSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -172,7 +150,7 @@ func (s *devicesStatements) insertDevice( }, nil } -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -180,7 +158,7 @@ func (s *devicesStatements) deleteDevice( return err } -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) @@ -198,7 +176,7 @@ func (s *devicesStatements) deleteDevices( return err } -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -206,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -214,7 +192,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -230,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -247,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -288,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) iDeviceIDs := make([]interface{}, len(deviceIDs)) for i := range deviceIDs { @@ -317,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 837d38cf1..81726edf9 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go index 4211ed0f1..e85e6f08b 100644 --- a/userapi/storage/sqlite3/key_backup_version_table.go +++ b/userapi/storage/sqlite3/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -67,12 +68,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go index 52322b46a..78d42029a 100644 --- a/userapi/storage/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -21,18 +21,17 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/util" ) type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectByTokenStmt *sql.Stmt + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt } -// execSchema ensures tables and indices exist. -func (s *loginTokenStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(` +const loginTokenSchema = ` CREATE TABLE IF NOT EXISTS login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, @@ -45,24 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens ( -- This index allows efficient garbage collection of expired tokens. CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); -`) - return err -} +` -// prepare runs statement preparation. -func (s *loginTokenStatements) prepare(db *sql.DB) error { - if err := s.execSchema(db); err != nil { - return err +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err } - return sqlutil.StatementList{ - {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, - {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, }.Prepare(db) } // insert adds an already generated token to the database. -func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { stmt := sqlutil.TxStmt(txn, s.insertStmt) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) return err @@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // The login_tokens_expiration_idx index should make that efficient. -func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) if err != nil { @@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t } // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. -func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 98c0488b1..d6090e0da 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { db *sql.DB insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return err +func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + db: db, + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index a92e95663..d85b19c7b 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -60,13 +61,15 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(profilesSchema) - if err != nil { - return +func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(profilesSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) error { _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return err } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( +func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) @@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL( return } -func (s *profilesStatements) setDisplayName( +func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) @@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName( return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 56ec1b6af..98c244977 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -15,80 +15,34 @@ package sqlite3 import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "encoding/json" - "errors" "fmt" - "strconv" - "sync" "time" "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" -) -// Database represents an account database -type Database struct { - db *sql.DB - writer sqlutil.Writer - - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - keyBackups keyBackupStatements - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 - - accountsMu sync.Mutex - profilesMu sync.Mutex - accountDatasMu sync.Mutex - threepidsMu sync.Mutex -} - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - loginTokenByteLength = 32 + // Import the postgres database driver. + _ "github.com/lib/pq" ) // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewExclusiveWriter(), - loginTokenLifetime: loginTokenLifetime, - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet return nil, err } - m := sqlutil.NewMigrations() deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) @@ -96,666 +50,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return nil, err } - partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - if err = d.devices.prepare(db, d.writer, serverName); err != nil { - return nil, err - } - if err = d.loginTokens.prepare(db); err != nil { - return nil, err - } - - return d, nil -} - -// GetAccountByPassword returns the account associated with the given localpart and password. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, -) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) + accountDataTable, err := NewSQLiteAccountDataTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { - return nil, err - } - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// GetProfileByLocalpart returns the profile associated with the given localpart. -// Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, -) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) -} - -// SetAvatarURL updates the avatar URL of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, -) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) - }) -} - -// SetDisplayName updates the display name of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, -) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setDisplayName(ctx, txn, localpart, displayName) - }) -} - -// SetPassword sets the account password to the given hash. -func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, -) error { - hash, err := d.hashPassword(plaintextPassword) + accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { - return err + return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) } - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.updatePassword(ctx, localpart, hash) - }) -} - -// CreateAccount makes a new account with the given login name and password, and creates an empty profile -// for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, ErrUserExists. -func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (acc *api.Account, err error) { - // Create one account at a time else we can get 'database is locked'. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - // For guest accounts, we create a new numeric local part - if accountType == api.AccountTypeGuest { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart = strconv.FormatInt(numLocalpart, 10) - plaintextPassword = "" - appserviceID = "" - } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) - return err - }) - return -} - -// WARNING! This function assumes that the relevant mutexes have already -// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). -func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (*api.Account, error) { - var err error - var account *api.Account - // Generate a password hash if this is not a password-less user - hash := "" - if plaintextPassword != "" { - hash, err = d.hashPassword(plaintextPassword) - if err != nil { - return nil, err - } - } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { - return nil, sqlutil.ErrUserExists - } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { - return nil, err - } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { - return nil, err - } - return account, nil -} - -// SaveAccountData saves new account data for a given user and a given room. -// If the account data is not specific to a room, the room ID should be an empty string -// If an account data already exists for a given set (user, room, data type), it will -// update the corresponding row with the new content -// Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, -) error { - d.accountDatasMu.Lock() - defer d.accountDatasMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) - }) -} - -// GetAccountData returns account data related to a given localpart -// If no account data could be found, returns an empty arrays -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, -) { - return d.accountDatas.selectAccountData(ctx, localpart) -} - -// GetAccountDataByType returns account data matching a given -// localpart, room ID and type. -// If no account data could be found, returns nil -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, -) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( - ctx, localpart, roomID, dataType, - ) -} - -// GetNewNumericLocalpart generates and returns a new unused numeric localpart -func (d *Database) GetNewNumericLocalpart( - ctx context.Context, -) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) -} - -func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) - return string(hashBytes), err -} - -// Err3PIDInUse is the error returned when trying to save an association involving -// a third-party identifier which is already associated to a local user. -var Err3PIDInUse = errors.New("this third-party identifier is already in use") - -// SaveThreePIDAssociation saves the association between a third party identifier -// and a local Matrix user (identified by the user's ID's local part). -// If the third-party identifier is already part of an association, returns Err3PIDInUse. -// Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, -) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } - - if len(user) > 0 { - return Err3PIDInUse - } - - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) -} - -// RemoveThreePIDAssociation removes the association involving a given third-party -// identifier. -// If no association exists involving this third-party identifier, returns nothing. -// If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation( - ctx context.Context, threepid string, medium string, -) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.threepids.deleteThreePID(ctx, txn, threepid, medium) - }) -} - -// GetLocalpartForThreePID looks up the localpart associated with a given third-party -// identifier. -// If no association involves the given third-party idenfitier, returns an empty -// string. -// Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID( - ctx context.Context, threepid string, medium string, -) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) -} - -// GetThreePIDsForLocalpart looks up the third-party identifiers associated with -// a given local user. -// If no association is known for this user, returns an empty slice. -// Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, -) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) -} - -// CheckAccountAvailability checks if the username/localpart is already present -// in the database. -// If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err -} - -// GetAccountByLocalpart returns the account associated with the given localpart. -// This function assumes the request is authenticated or the account data is used only internally. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// SearchProfiles returns all profiles where the provided localpart or display name -// match any part of the profiles in the database. -func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, -) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) -} - -// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.deactivateAccount(ctx, localpart) - }) -} - -// CreateOpenIDToken persists a new token that was issued for OpenID Connect -func (d *Database) CreateOpenIDToken( - ctx context.Context, - token, localpart string, -) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) - }) - return expiresAtMS, err -} - -// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token -func (d *Database) GetOpenIDTokenAttributes( - ctx context.Context, - token string, -) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) -} - -func (d *Database) CreateKeyBackup( - ctx context.Context, userID, algorithm string, authData json.RawMessage, -) (version string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") - return err - }) - return -} - -func (d *Database) UpdateKeyBackupAuthData( - ctx context.Context, userID, version string, authData json.RawMessage, -) (err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) - }) - return -} - -func (d *Database) DeleteKeyBackup( - ctx context.Context, userID, version string, -) (exists bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetKeyBackup( - ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetBackupKeys( - ctx context.Context, version, userID, filterRoomID, filterSessionID string, -) (result map[string]map[string]api.KeyBackupSession, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) - return err - } - if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) - return err - } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) CountBackupKeys( - ctx context.Context, version, userID string, -) (count int64, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - return nil - }) - return -} - -// nolint:nakedret -func (d *Database) UpsertBackupKeys( - ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, -) (count int64, etag string, err error) { - // wrap the following logic in a txn to ensure we atomically upload keys - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - if err != nil { - return err - } - if deleted { - return fmt.Errorf("backup was deleted") - } - // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) - if err != nil { - return err - } - - changed := false - // loop over all the new keys (which should be smaller than the set of backed up keys) - for _, newKey := range uploads { - // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. - existingRoom := existingKeys[newKey.RoomID] - if existingRoom != nil { - existingSession, ok := existingRoom[newKey.SessionID] - if ok { - if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) - } - } - // if we shouldn't replace the key we do nothing with it - continue - } - } - // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) - } - } - - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - if changed { - // update the etag - var newETag string - if oldETag == "" { - newETag = "1" - } else { - oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) - if err != nil { - return fmt.Errorf("failed to parse old etag: %s", err) - } - newETag = strconv.FormatInt(oldETagInt+1, 10) - } - etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) - } else { - etag = oldETag - } - - return nil - }) - return -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) + devicesTable, err := NewSQLiteDevicesTable(db, serverName) if err != nil { - return "", err + return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() + keyBackupTable, err := NewSQLiteKeyBackupTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err) } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) + keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err) } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) + loginTokenTable, err := NewSQLiteLoginTokenTable(db) if err != nil { - return "", err + return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err) } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) + openIDTable, err := NewSQLiteOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err) + } + profilesTable, err := NewSQLiteProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err) + } + threePIDTable, err := NewSQLiteThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil } diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 9dc0e2d22..fa174eed5 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -60,13 +61,15 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(threepidSchema) - if err != nil { - return +func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(threepidSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return threepids, rows.Err() } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID( return err } -func (s *threepidStatements) deleteThreePID( +func (s *threepidStatements) DeleteThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go new file mode 100644 index 000000000..12939ced5 --- /dev/null +++ b/userapi/storage/tables/interface.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" +) + +type AccountDataTable interface { + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) +} + +type AccountsTable interface { + InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string) (err error) + SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) +} + +type DevicesTable interface { + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) + SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error +} + +type KeyBackupTable interface { + CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error) + InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error) +} + +type KeyBackupVersionTable interface { + InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error + UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error + DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error) + SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) +} + +type LoginTokenTable interface { + InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error + DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error + SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) +} + +type OpenIDTable interface { + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) +} + +type ProfileTable interface { + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error + SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) +} + +type ThreePIDTable interface { + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) +}