From 89a6787fdb60a3b4e7054726ebe0851128e9e16a Mon Sep 17 00:00:00 2001 From: S7evinK Date: Mon, 7 Jun 2021 10:17:46 +0200 Subject: [PATCH] Try to optimize SelectOneTimeKeys (#1851) * Try to optimize SelectOneTimeKeys Signed-off-by: Till Faelligen * Use pg.Array when using ANY... Co-authored-by: Kegsay --- .../storage/postgres/one_time_keys_table.go | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index 6e32838b1..cc397ba84 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "time" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" @@ -47,7 +48,7 @@ const upsertKeysSQL = "" + " DO UPDATE SET key_json = $6" const selectKeysSQL = "" + - "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" const selectKeysCountSQL = "" + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" @@ -94,29 +95,22 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms)) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") - wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) - for _, ka := range keyIDsWithAlgorithms { - wantSet[ka] = true - } - result := make(map[string]json.RawMessage) + var ( + algorithmWithID string + keyJSONStr string + ) for rows.Next() { - var keyID string - var algorithm string - var keyJSONStr string - if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { return nil, err } - keyIDWithAlgo := algorithm + ":" + keyID - if wantSet[keyIDWithAlgo] { - result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) - } + result[algorithmWithID] = json.RawMessage(keyJSONStr) } return result, rows.Err() }