diff --git a/database/message.go b/database/message.go index 254c9d3..1b6bbbf 100644 --- a/database/message.go +++ b/database/message.go @@ -43,7 +43,7 @@ func (mq *MessageQuery) New() *Message { } func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { - rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver) + rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver) if err != nil || rows == nil { return nil } @@ -55,18 +55,19 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { } func (mq *MessageQuery) GetByJID(chat PortalKey, jid whatsapp.MessageID) *Message { - return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content "+ + return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content "+ "FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid) } func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message { - return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content "+ + return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content "+ "FROM message WHERE mxid=$1", mxid) } func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message { - msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content "+ - "FROM message WHERE chat_jid=$1 AND chat_receiver=$2 ORDER BY timestamp DESC LIMIT 1", chat.JID, chat.Receiver) + msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content "+ + "FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp DESC LIMIT 1", + chat.JID, chat.Receiver) if msg == nil || msg.Timestamp == 0 { // Old db, we don't know what the last message is. return nil @@ -91,6 +92,7 @@ type Message struct { MXID id.EventID Sender whatsapp.JID Timestamp uint64 + Sent bool Content *waProto.Message } @@ -100,7 +102,7 @@ func (msg *Message) IsFakeMXID() bool { func (msg *Message) Scan(row Scannable) *Message { var content []byte - err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.Timestamp, &content) + err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.Timestamp, &msg.Sent, &content) if err != nil { if err != sql.ErrNoRows { msg.log.Errorln("Database scan failed:", err) @@ -134,14 +136,23 @@ func (msg *Message) encodeBinaryContent() []byte { } func (msg *Message) Insert() { - _, err := msg.db.Exec("INSERT INTO message (chat_jid, chat_receiver, jid, mxid, sender, timestamp, content) "+ - "VALUES ($1, $2, $3, $4, $5, $6, $7)", - msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.Timestamp, msg.encodeBinaryContent()) + _, err := msg.db.Exec(`INSERT INTO message + (chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, content) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.Timestamp, msg.Sent, msg.encodeBinaryContent()) if err != nil { msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err) } } +func (msg *Message) MarkSent() { + msg.Sent = true + _, err := msg.db.Exec("UPDATE message SET sent=true WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID) + if err != nil { + msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err) + } +} + func (msg *Message) Delete() { _, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID) if err != nil { diff --git a/database/upgrades/2021-02-17-message-sent-status.go b/database/upgrades/2021-02-17-message-sent-status.go new file mode 100644 index 0000000..a5852b0 --- /dev/null +++ b/database/upgrades/2021-02-17-message-sent-status.go @@ -0,0 +1,12 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[20] = upgrade{"Add sent column for messages", func(tx *sql.Tx, ctx context) error { + _, err := tx.Exec(`ALTER TABLE message ADD COLUMN sent BOOLEAN NOT NULL DEFAULT true`) + return err + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index e224593..af54f6c 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -39,7 +39,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 20 +const NumberOfUpgrades = 21 var upgrades [NumberOfUpgrades]upgrade diff --git a/portal.go b/portal.go index cc0dc29..d1703ab 100644 --- a/portal.go +++ b/portal.go @@ -283,7 +283,7 @@ func init() { gob.Register(&waProto.Message{}) } -func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid id.EventID) { +func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid id.EventID, isSent bool) *database.Message { msg := portal.bridge.DB.Message.New() msg.Chat = portal.Key msg.JID = message.GetKey().GetId() @@ -300,6 +300,7 @@ func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, } } msg.Content = message.Message + msg.Sent = isSent msg.Insert() portal.recentlyHandledLock.Lock() @@ -307,6 +308,7 @@ func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength portal.recentlyHandledLock.Unlock() portal.recentlyHandled[index] = msg.JID + return msg } func (portal *Portal) getMessageIntent(user *User, info whatsapp.MessageInfo) *appservice.IntentAPI { @@ -346,7 +348,7 @@ func (portal *Portal) startHandling(source *User, info whatsapp.MessageInfo) *ap } func (portal *Portal) finishHandling(source *User, message *waProto.WebMessageInfo, mxid id.EventID) { - portal.markHandled(source, message, mxid) + portal.markHandled(source, message, mxid, true) portal.sendDeliveryReceipt(mxid) portal.log.Debugln("Handled message", message.GetKey().GetId(), "->", mxid) } @@ -735,6 +737,10 @@ func (portal *Portal) BackfillHistory(user *User, lastMessageTime uint64) error for len(lastMessageID) > 0 { portal.log.Debugln("Fetching 50 messages of history after", lastMessageID) resp, err := user.Conn.LoadMessagesAfter(portal.Key.JID, lastMessageID, lastMessageFromMe, 50) + if err == whatsapp.ErrServerRespondedWith404 { + portal.log.Warnln("Got 404 response trying to fetch messages to backfill. Fetching latest messages as fallback.") + resp, err = user.Conn.LoadMessagesBefore(portal.Key.JID, "", true, 50) + } if err != nil { return err } @@ -1322,7 +1328,7 @@ func (portal *Portal) HandleStubMessage(source *User, message whatsapp.StubMessa if len(eventID) == 0 { eventID = id.EventID(fmt.Sprintf("net.maunium.whatsapp.fake::%s", message.Info.Id)) } - portal.markHandled(source, message.Info.Source, eventID) + portal.markHandled(source, message.Info.Source, eventID, true) } func (portal *Portal) HandleLocationMessage(source *User, message whatsapp.LocationMessage) { @@ -2087,12 +2093,12 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) { if info == nil { return } - portal.markHandled(sender, info, evt.ID) + dbMsg := portal.markHandled(sender, info, evt.ID, false) portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.Key.GetId()) - portal.sendRaw(sender, evt, info) + portal.sendRaw(sender, evt, info, dbMsg) } -func (portal *Portal) sendRaw(sender *User, evt *event.Event, info *waProto.WebMessageInfo) { +func (portal *Portal) sendRaw(sender *User, evt *event.Event, info *waProto.WebMessageInfo, dbMsg *database.Message) { errChan := make(chan error, 1) go sender.Conn.SendRaw(info, errChan) @@ -2112,16 +2118,11 @@ func (portal *Portal) sendRaw(sender *User, evt *event.Event, info *waProto.WebM } if err != nil { portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err) - var statusResp whatsapp.StatusResponse - if errors.As(err, &statusResp) && statusResp.Status == 599 { - portal.log.Debugfln("599 status response extra data: %+v", statusResp.Extra) - portal.sendErrorMessage(fmt.Sprintf("%v. Please try again after a few minutes", err)) - } else { - portal.sendErrorMessage(err.Error()) - } + portal.sendErrorMessage(err.Error()) } else { portal.log.Debugfln("Handled Matrix event %s", evt.ID) portal.sendDeliveryReceipt(evt.ID) + dbMsg.MarkSent() } if errorEventID != "" { _, err = portal.MainIntent().RedactEvent(portal.MXID, errorEventID)