diff --git a/src/github.com/matrix-org/dendrite/common/keydb/keydb.go b/src/github.com/matrix-org/dendrite/common/keydb/keydb.go index 8d5a6ddee..51444ab29 100644 --- a/src/github.com/matrix-org/dendrite/common/keydb/keydb.go +++ b/src/github.com/matrix-org/dendrite/common/keydb/keydb.go @@ -49,7 +49,7 @@ func (d *Database) FetchKeys( ctx context.Context, requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { - return d.statements.bulkSelectServerKeys(requests) + return d.statements.bulkSelectServerKeys(ctx, requests) } // StoreKeys implements gomatrixserverlib.KeyDatabase @@ -62,7 +62,7 @@ func (d *Database) StoreKeys( // high for a single insert statement. var lastErr error for request, keys := range keyMap { - if err := d.statements.upsertServerKeys(request, keys); err != nil { + if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil { // Rather than returning immediately on error we try to insert the // remaining keys. // Since we are inserting the keys outside of a transaction it is diff --git a/src/github.com/matrix-org/dendrite/common/keydb/server_key_table.go b/src/github.com/matrix-org/dendrite/common/keydb/server_key_table.go index e89ebcda3..7d9455c12 100644 --- a/src/github.com/matrix-org/dendrite/common/keydb/server_key_table.go +++ b/src/github.com/matrix-org/dendrite/common/keydb/server_key_table.go @@ -15,6 +15,7 @@ package keydb import ( + "context" "database/sql" "encoding/json" @@ -73,13 +74,15 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { } func (s *serverKeyStatements) bulkSelectServerKeys( + ctx context.Context, requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { var nameAndKeyIDs []string for request := range requests { nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - rows, err := s.bulkSelectServerKeysStmt.Query(pq.StringArray(nameAndKeyIDs)) + stmt := s.bulkSelectServerKeysStmt + rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs)) if err != nil { return nil, err } @@ -106,15 +109,21 @@ func (s *serverKeyStatements) bulkSelectServerKeys( } func (s *serverKeyStatements) upsertServerKeys( - request gomatrixserverlib.PublicKeyRequest, keys gomatrixserverlib.ServerKeys, + ctx context.Context, + request gomatrixserverlib.PublicKeyRequest, + keys gomatrixserverlib.ServerKeys, ) error { keyJSON, err := json.Marshal(keys) if err != nil { return err } - _, err = s.upsertServerKeysStmt.Exec( - string(request.ServerName), string(request.KeyID), nameAndKeyID(request), - int64(keys.ValidUntilTS), keyJSON, + _, err = s.upsertServerKeysStmt.ExecContext( + ctx, + string(request.ServerName), + string(request.KeyID), + nameAndKeyID(request), + int64(keys.ValidUntilTS), + keyJSON, ) return err }