mirror of
https://github.com/matrix-org/dendrite
synced 2024-11-20 08:40:14 +01:00
Use returned ID from INSERT in create filter (#297)
This commit is contained in:
parent
f6bda82366
commit
c0271c2462
2 changed files with 5 additions and 24 deletions
|
@ -15,8 +15,8 @@
|
|||
package accounts
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const filterSchema = `
|
||||
|
@ -41,13 +41,9 @@ const selectFilterSQL = "" +
|
|||
const insertFilterSQL = "" +
|
||||
"INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
|
||||
|
||||
const findMaxIDSQL = "" +
|
||||
"SELECT MAX(id) FROM account_filter WHERE localpart = $1"
|
||||
|
||||
type filterStatements struct {
|
||||
selectFilterStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
findMaxIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||
|
@ -61,10 +57,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
|||
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.findMaxIDStmt, err = db.Prepare(findMaxIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -77,14 +69,7 @@ func (s *filterStatements) selectFilter(
|
|||
|
||||
func (s *filterStatements) insertFilter(
|
||||
ctx context.Context, filter string, localpart string,
|
||||
) (err error) {
|
||||
_, err = s.insertFilterStmt.ExecContext(ctx, filter, localpart)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *filterStatements) findMaxID(
|
||||
ctx context.Context, localpart string,
|
||||
) (id string, err error) {
|
||||
err = s.findMaxIDStmt.QueryRowContext(ctx, localpart).Scan(&id)
|
||||
) (pos string, err error) {
|
||||
err = s.insertFilterStmt.QueryRowContext(ctx, filter, localpart).Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ import (
|
|||
|
||||
// Database represents an account database
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
common.PartitionOffsetStatements
|
||||
accounts accountsStatements
|
||||
profiles profilesStatements
|
||||
|
@ -333,11 +333,7 @@ func (d *Database) GetFilter(
|
|||
func (d *Database) PutFilter(
|
||||
ctx context.Context, localpart, filter string,
|
||||
) (string, error) {
|
||||
err := d.filter.insertFilter(ctx, filter, localpart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return d.filter.findMaxID(ctx, localpart)
|
||||
return d.filter.insertFilter(ctx, filter, localpart)
|
||||
}
|
||||
|
||||
// CheckAccountAvailability checks if the username/localpart is already present in the database.
|
||||
|
|
Loading…
Reference in a new issue