diff --git a/database/upgrades/2018-09-01-initial-schema.go b/database/upgrades/2018-09-01-initial-schema.go index 189d3d0..3e0c7c6 100644 --- a/database/upgrades/2018-09-01-initial-schema.go +++ b/database/upgrades/2018-09-01-initial-schema.go @@ -6,7 +6,7 @@ import ( ) func init() { - upgrades[0] = upgrade{"Initial schema", func(dialect Dialect, tx *sql.Tx) error { + upgrades[0] = upgrade{"Initial schema", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { var byteType string if dialect == SQLite { byteType = "BLOB" diff --git a/database/upgrades/2019-05-16-message-delete-cascade.go b/database/upgrades/2019-05-16-message-delete-cascade.go index cb31cd4..131e086 100644 --- a/database/upgrades/2019-05-16-message-delete-cascade.go +++ b/database/upgrades/2019-05-16-message-delete-cascade.go @@ -5,7 +5,7 @@ import ( ) func init() { - upgrades[1] = upgrade{"Add ON DELETE CASCADE to message table", func(dialect Dialect, tx *sql.Tx) error { + upgrades[1] = upgrade{"Add ON DELETE CASCADE to message table", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { if dialect == SQLite { // SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway. return nil diff --git a/database/upgrades/2019-05-21-message-timestamp-column.go b/database/upgrades/2019-05-21-message-timestamp-column.go index b53a340..ebd7e38 100644 --- a/database/upgrades/2019-05-21-message-timestamp-column.go +++ b/database/upgrades/2019-05-21-message-timestamp-column.go @@ -5,7 +5,7 @@ import ( ) func init() { - upgrades[2] = upgrade{"Add timestamp column to messages", func(dialect Dialect, tx *sql.Tx) error { + upgrades[2] = upgrade{"Add timestamp column to messages", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { _, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0") if err != nil { return err diff --git a/database/upgrades/2019-05-22-user-last-connection-column.go b/database/upgrades/2019-05-22-user-last-connection-column.go index e53a3b8..db1a214 100644 --- a/database/upgrades/2019-05-22-user-last-connection-column.go +++ b/database/upgrades/2019-05-22-user-last-connection-column.go @@ -5,7 +5,7 @@ import ( ) func init() { - upgrades[3] = upgrade{"Add last_connection column to users", func(dialect Dialect, tx *sql.Tx) error { + upgrades[3] = upgrade{"Add last_connection column to users", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { _, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`) if err != nil { return err diff --git a/database/upgrades/2019-05-23-protoupgrade.go b/database/upgrades/2019-05-23-protoupgrade.go new file mode 100644 index 0000000..1e2d72e --- /dev/null +++ b/database/upgrades/2019-05-23-protoupgrade.go @@ -0,0 +1,61 @@ +package upgrades + +import ( + "database/sql" + "encoding/json" + "fmt" +) + +func init() { + var keys = []string{"imageMessage", "contactMessage", "locationMessage", "extendedTextMessage", "documentMessage", "audioMessage", "videoMessage"} + upgrades[4] = upgrade{"Update message content to new protocol version", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { + rows, err := db.Query("SELECT mxid, content FROM message") + if err != nil { + return err + } + for rows.Next() { + var mxid string + var rawContent []byte + err = rows.Scan(&mxid, &rawContent) + if err != nil { + fmt.Println("Error scanning:", err) + continue + } + var content map[string]interface{} + err = json.Unmarshal(rawContent, &content) + if err != nil { + fmt.Printf("Error unmarshaling content of %s: %v\n", mxid, err) + continue + } + + for _, key := range keys { + val, ok := content[key].(map[string]interface{}) + if !ok { + continue + } + ci, ok := val["contextInfo"].(map[string]interface{}) + if !ok { + continue + } + qm, ok := ci["quotedMessage"].([]interface{}) + if !ok { + continue + } + ci["quotedMessage"] = qm[0] + goto save + } + continue + + save: + rawContent, err = json.Marshal(&content) + if err != nil { + fmt.Printf("Error marshaling updated content of %s: %v\n", mxid, err) + } + _, err = tx.Exec("UPDATE message SET content=$1 WHERE mxid=$2", rawContent, mxid) + if err != nil { + fmt.Printf("Error updating row of %s: %v\n", mxid, err) + } + } + return nil + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index cbe3d76..5e98872 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -15,14 +15,14 @@ const ( SQLite ) -type upgradeFunc func(Dialect, *sql.Tx) error +type upgradeFunc func(Dialect, *sql.Tx, *sql.DB) error type upgrade struct { message string fn upgradeFunc } -var upgrades [4]upgrade +var upgrades [5]upgrade func getVersion(dialect Dialect, db *sql.DB) (int, error) { _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)") @@ -70,7 +70,7 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error { if err != nil { return err } - err = upgrade.fn(dialect, tx) + err = upgrade.fn(dialect, tx, db) if err != nil { return err }