mirror of
https://github.com/tulir/mautrix-whatsapp
synced 2024-12-13 17:13:11 +01:00
Fix some bugs with db crypto store
This commit is contained in:
parent
dfc5722a80
commit
ea23907492
4 changed files with 39 additions and 9 deletions
|
@ -126,7 +126,7 @@ func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
|
||||||
ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
|
ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
|
||||||
store.DeviceID, account.Shared, store.SyncToken, bytes)
|
store.DeviceID, account.Shared, store.SyncToken, bytes)
|
||||||
} else if store.db.dialect == "sqlite3" {
|
} else if store.db.dialect == "sqlite3" {
|
||||||
_, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (deivce_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
|
_, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
|
||||||
store.DeviceID, account.Shared, store.SyncToken, bytes)
|
store.DeviceID, account.Shared, store.SyncToken, bytes)
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
|
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
|
||||||
|
@ -270,11 +270,11 @@ func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessio
|
||||||
var resultEventID id.EventID
|
var resultEventID id.EventID
|
||||||
var resultTimestamp int64
|
var resultTimestamp int64
|
||||||
err := store.db.QueryRow(
|
err := store.db.QueryRow(
|
||||||
"SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND index=$3",
|
`SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`,
|
||||||
senderKey, sessionID, index,
|
senderKey, sessionID, index,
|
||||||
).Scan(&resultEventID, &resultTimestamp)
|
).Scan(&resultEventID, &resultTimestamp)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
_, err := store.db.Exec("INSERT INTO crypto_message_index (sender_key, session_id, index, event_id, timestamp) VALUES ($1, $2, $3, $4, $5)",
|
_, err := store.db.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`,
|
||||||
senderKey, sessionID, index, eventID, timestamp)
|
senderKey, sessionID, index, eventID, timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
store.log.Warnln("Failed to store message index:", err)
|
store.log.Warnln("Failed to store message index:", err)
|
||||||
|
@ -325,7 +325,7 @@ func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceI
|
||||||
if store.db.dialect == "postgres" {
|
if store.db.dialect == "postgres" {
|
||||||
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||||
} else if store.db.dialect == "sqlite3" {
|
} else if store.db.dialect == "sqlite3" {
|
||||||
_, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_users (user_id) VALUES ($1)", userID)
|
_, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_user (user_id) VALUES ($1)", userID)
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
|
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
|
||||||
}
|
}
|
||||||
|
@ -374,7 +374,13 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
|
||||||
if store.db.dialect == "postgres" {
|
if store.db.dialect == "postgres" {
|
||||||
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
|
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
|
||||||
} else {
|
} else {
|
||||||
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ($1)", users)
|
queryString := make([]string, len(users))
|
||||||
|
params := make([]interface{}, len(users))
|
||||||
|
for i, user := range users {
|
||||||
|
queryString[i] = fmt.Sprintf("$%d", i+1)
|
||||||
|
params[i] = user
|
||||||
|
}
|
||||||
|
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
store.log.Warnln("Failed to filter tracked users:", err)
|
store.log.Warnln("Failed to filter tracked users:", err)
|
||||||
|
|
|
@ -89,7 +89,7 @@ func migrateTable(old *Database, new *Database, table string, columns ...string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Migrate(old *Database, new *Database) {
|
func Migrate(old *Database, new *Database) {
|
||||||
err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url")
|
err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url", "encrypted")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -121,4 +121,28 @@ func Migrate(old *Database, new *Database) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
err = migrateTable(old, new, "crypto_account", "device_id", "shared", "sync_token", "account")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = migrateTable(old, new, "crypto_message_index", "sender_key", "session_id", `"index"`, "event_id", "timestamp")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = migrateTable(old, new, "crypto_tracked_user", "user_id")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = migrateTable(old, new, "crypto_device", "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = migrateTable(old, new, "crypto_olm_session", "session_id", "sender_key", "session", "created_at", "last_used")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = migrateTable(old, new, "crypto_megolm_inbound_session", "session_id", "sender_key", "signing_key", "room_id", "session", "forwarding_chains")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,11 +19,11 @@ func init() {
|
||||||
_, err = tx.Exec(`CREATE TABLE crypto_message_index (
|
_, err = tx.Exec(`CREATE TABLE crypto_message_index (
|
||||||
sender_key CHAR(43),
|
sender_key CHAR(43),
|
||||||
session_id VARCHAR(255),
|
session_id VARCHAR(255),
|
||||||
index INTEGER,
|
"index" INTEGER,
|
||||||
event_id VARCHAR(255) NOT NULL,
|
event_id VARCHAR(255) NOT NULL,
|
||||||
timestamp BIGINT NOT NULL,
|
timestamp BIGINT NOT NULL,
|
||||||
|
|
||||||
PRIMARY KEY (sender_key, session_id, index)
|
PRIMARY KEY (sender_key, session_id, "index")
|
||||||
)`)
|
)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -217,7 +217,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
|
||||||
|
|
||||||
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
|
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mx.log.Warnln("Failed to decrypt %s: %v", evt.ID, err)
|
mx.log.Warnfln("Failed to decrypt %s: %v", evt.ID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mx.bridge.EventProcessor.Dispatch(decrypted)
|
mx.bridge.EventProcessor.Dispatch(decrypted)
|
||||||
|
|
Loading…
Reference in a new issue