From ea239074921c5c708a9fb52a186587d457ad8226 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 9 May 2020 20:23:30 +0300 Subject: [PATCH] Fix some bugs with db crypto store --- database/cryptostore.go | 16 ++++++++---- database/migrate.go | 26 +++++++++++++++++++- database/upgrades/2020-05-09-crypto-store.go | 4 +-- matrix.go | 2 +- 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/database/cryptostore.go b/database/cryptostore.go index 8b36216..0f780b2 100644 --- a/database/cryptostore.go +++ b/database/cryptostore.go @@ -126,7 +126,7 @@ func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error { ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`, store.DeviceID, account.Shared, store.SyncToken, bytes) } else if store.db.dialect == "sqlite3" { - _, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (deivce_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)", + _, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)", store.DeviceID, account.Shared, store.SyncToken, bytes) } else { err = fmt.Errorf("unsupported dialect %s", store.db.dialect) @@ -270,11 +270,11 @@ func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessio var resultEventID id.EventID var resultTimestamp int64 err := store.db.QueryRow( - "SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND index=$3", + `SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`, senderKey, sessionID, index, ).Scan(&resultEventID, &resultTimestamp) if err == sql.ErrNoRows { - _, err := store.db.Exec("INSERT INTO crypto_message_index (sender_key, session_id, index, event_id, timestamp) VALUES ($1, $2, $3, $4, $5)", + _, err := store.db.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`, senderKey, sessionID, index, eventID, timestamp) if err != nil { store.log.Warnln("Failed to store message index:", err) @@ -325,7 +325,7 @@ func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceI if store.db.dialect == "postgres" { _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) } else if store.db.dialect == "sqlite3" { - _, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_users (user_id) VALUES ($1)", userID) + _, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_user (user_id) VALUES ($1)", userID) } else { err = fmt.Errorf("unsupported dialect %s", store.db.dialect) } @@ -374,7 +374,13 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID { if store.db.dialect == "postgres" { rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users)) } else { - rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ($1)", users) + queryString := make([]string, len(users)) + params := make([]interface{}, len(users)) + for i, user := range users { + queryString[i] = fmt.Sprintf("$%d", i+1) + params[i] = user + } + rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...) } if err != nil { store.log.Warnln("Failed to filter tracked users:", err) diff --git a/database/migrate.go b/database/migrate.go index b3cf4e0..9d30871 100644 --- a/database/migrate.go +++ b/database/migrate.go @@ -89,7 +89,7 @@ func migrateTable(old *Database, new *Database, table string, columns ...string) } func Migrate(old *Database, new *Database) { - err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url") + err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url", "encrypted") if err != nil { panic(err) } @@ -121,4 +121,28 @@ func Migrate(old *Database, new *Database) { if err != nil { panic(err) } + err = migrateTable(old, new, "crypto_account", "device_id", "shared", "sync_token", "account") + if err != nil { + panic(err) + } + err = migrateTable(old, new, "crypto_message_index", "sender_key", "session_id", `"index"`, "event_id", "timestamp") + if err != nil { + panic(err) + } + err = migrateTable(old, new, "crypto_tracked_user", "user_id") + if err != nil { + panic(err) + } + err = migrateTable(old, new, "crypto_device", "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name") + if err != nil { + panic(err) + } + err = migrateTable(old, new, "crypto_olm_session", "session_id", "sender_key", "session", "created_at", "last_used") + if err != nil { + panic(err) + } + err = migrateTable(old, new, "crypto_megolm_inbound_session", "session_id", "sender_key", "signing_key", "room_id", "session", "forwarding_chains") + if err != nil { + panic(err) + } } diff --git a/database/upgrades/2020-05-09-crypto-store.go b/database/upgrades/2020-05-09-crypto-store.go index 529ff9c..ea454dc 100644 --- a/database/upgrades/2020-05-09-crypto-store.go +++ b/database/upgrades/2020-05-09-crypto-store.go @@ -19,11 +19,11 @@ func init() { _, err = tx.Exec(`CREATE TABLE crypto_message_index ( sender_key CHAR(43), session_id VARCHAR(255), - index INTEGER, + "index" INTEGER, event_id VARCHAR(255) NOT NULL, timestamp BIGINT NOT NULL, - PRIMARY KEY (sender_key, session_id, index) + PRIMARY KEY (sender_key, session_id, "index") )`) if err != nil { return err diff --git a/matrix.go b/matrix.go index d064f49..3e524f4 100644 --- a/matrix.go +++ b/matrix.go @@ -217,7 +217,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { decrypted, err := mx.bridge.Crypto.Decrypt(evt) if err != nil { - mx.log.Warnln("Failed to decrypt %s: %v", evt.ID, err) + mx.log.Warnfln("Failed to decrypt %s: %v", evt.ID, err) return } mx.bridge.EventProcessor.Dispatch(decrypted)