From 2bc0e52250091e191ed3ef5a939506a77d51ac9e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Jun 2019 21:28:14 +0300 Subject: [PATCH] Fix puppet db inserts. Fixes #69 --- database/message.go | 2 +- database/puppet.go | 2 +- database/upgrades/2019-05-23-protoupgrade.go | 2 +- database/user.go | 9 ++++----- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/database/message.go b/database/message.go index 78a8043..f7976fd 100644 --- a/database/message.go +++ b/database/message.go @@ -41,7 +41,7 @@ func (mq *MessageQuery) New() *Message { } func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { - rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver) + rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver) if err != nil || rows == nil { return nil } diff --git a/database/puppet.go b/database/puppet.go index 8a9cfae..bb35923 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -112,7 +112,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet { } func (puppet *Puppet) Insert() { - _, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", + _, err := puppet.db.Exec("INSERT INTO puppet (jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", puppet.JID, puppet.Avatar, puppet.AvatarURL, puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch) if err != nil { puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err) diff --git a/database/upgrades/2019-05-23-protoupgrade.go b/database/upgrades/2019-05-23-protoupgrade.go index 1e2d72e..4bbc01a 100644 --- a/database/upgrades/2019-05-23-protoupgrade.go +++ b/database/upgrades/2019-05-23-protoupgrade.go @@ -8,7 +8,7 @@ import ( func init() { var keys = []string{"imageMessage", "contactMessage", "locationMessage", "extendedTextMessage", "documentMessage", "audioMessage", "videoMessage"} - upgrades[4] = upgrade{"Update message content to new protocol version", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { + upgrades[4] = upgrade{"Update message content to new protocol version. This may take a while.", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { rows, err := db.Query("SELECT mxid, content FROM message") if err != nil { return err diff --git a/database/user.go b/database/user.go index 300d12e..2d621cc 100644 --- a/database/user.go +++ b/database/user.go @@ -43,7 +43,7 @@ func (uq *UserQuery) New() *User { } func (uq *UserQuery) GetAll() (users []*User) { - rows, err := uq.db.Query(`SELECT * FROM "user"`) + rows, err := uq.db.Query(`SELECT mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key FROM "user"`) if err != nil || rows == nil { return nil } @@ -55,7 +55,7 @@ func (uq *UserQuery) GetAll() (users []*User) { } func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User { - row := uq.db.QueryRow(`SELECT * FROM "user" WHERE mxid=$1`, userID) + row := uq.db.QueryRow(`SELECT mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key FROM "user" WHERE mxid=$1`, userID) if row == nil { return nil } @@ -63,7 +63,7 @@ func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User { } func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User { - row := uq.db.QueryRow(`SELECT * FROM "user" WHERE jid=$1`, stripSuffix(userID)) + row := uq.db.QueryRow(`SELECT mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key FROM "user" WHERE jid=$1`, stripSuffix(userID)) if row == nil { return nil } @@ -84,8 +84,7 @@ type User struct { func (user *User) Scan(row Scannable) *User { var jid, clientID, clientToken, serverToken sql.NullString var encKey, macKey []byte - err := row.Scan(&user.MXID, &jid, &user.ManagementRoom, &clientID, &clientToken, &serverToken, &encKey, &macKey, - &user.LastConnection) + err := row.Scan(&user.MXID, &jid, &user.ManagementRoom, &user.LastConnection, &clientID, &clientToken, &serverToken, &encKey, &macKey) if err != nil { if err != sql.ErrNoRows { user.log.Errorln("Database scan failed:", err)