diff --git a/database/historysync.go b/database/historysync.go index 89fefb6..8532c5e 100644 --- a/database/historysync.go +++ b/database/historysync.go @@ -20,6 +20,8 @@ import ( "database/sql" "errors" "fmt" + "strconv" + "strings" "time" waProto "go.mau.fi/whatsmeow/binary/proto" @@ -203,7 +205,7 @@ func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) error { const ( getMessagesBetween = ` - SELECT data + SELECT id, data FROM history_sync_message WHERE user_mxid=$1 AND conversation_id=$2 @@ -211,18 +213,28 @@ const ( ORDER BY timestamp DESC %s ` + deleteMessages = ` + DELETE FROM history_sync_message + WHERE id IN (%s) + ` ) type HistorySyncMessage struct { db *Database log log.Logger + ID int UserID id.UserID ConversationID string Timestamp time.Time Data []byte } +type WrappedWebMessageInfo struct { + ID int + Message *waProto.WebMessageInfo +} + func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) { msgData, err := proto.Marshal(message) if err != nil { @@ -248,7 +260,7 @@ func (hsm *HistorySyncMessage) Insert() { } } -func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) { +func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*WrappedWebMessageInfo) { whereClauses := "" args := []interface{}{userID, conversationID} argNum := 3 @@ -272,9 +284,10 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID if err != nil || rows == nil { return nil } + var msgID int var msgData []byte for rows.Next() { - err := rows.Scan(&msgData) + err := rows.Scan(&msgID, &msgData) if err != nil { hsq.log.Error("Database scan failed: %v", err) continue @@ -285,11 +298,24 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID hsq.log.Errorf("Failed to unmarshal history sync message: %v", err) continue } - messages = append(messages, historySyncMsg.Message) + messages = append(messages, &WrappedWebMessageInfo{ + ID: msgID, + Message: historySyncMsg.Message, + }) } return } +func (hsq *HistorySyncQuery) DeleteMessages(messages []*WrappedWebMessageInfo) error { + messageIDs := make([]string, len(messages)) + for i, msg := range messages { + messageIDs[i] = strconv.Itoa(msg.ID) + } + + _, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(messageIDs, ","))) + return err +} + func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) error { _, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID) return err diff --git a/database/message.go b/database/message.go index e782d9f..ad90cba 100644 --- a/database/message.go +++ b/database/message.go @@ -143,6 +143,7 @@ type Message struct { db *Database log log.Logger + ID int Chat PortalKey JID types.MessageID MXID id.EventID diff --git a/database/upgrades/2022-03-18-historysync-store.go b/database/upgrades/2022-03-18-historysync-store.go index bf53ab1..5597afb 100644 --- a/database/upgrades/2022-03-18-historysync-store.go +++ b/database/upgrades/2022-03-18-historysync-store.go @@ -34,6 +34,7 @@ func init() { } _, err = tx.Exec(` CREATE TABLE history_sync_message ( + id SERIAL PRIMARY KEY, user_mxid TEXT, conversation_id TEXT, timestamp TIMESTAMP, @@ -74,6 +75,7 @@ func init() { } _, err = tx.Exec(` CREATE TABLE history_sync_message ( + id INTEGER PRIMARY KEY, user_mxid TEXT, conversation_id TEXT, timestamp DATETIME, diff --git a/historysync.go b/historysync.go index 553dfff..cea3203 100644 --- a/historysync.go +++ b/historysync.go @@ -126,7 +126,7 @@ func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Bac break } - var msgs []*waProto.WebMessageInfo + var msgs []*database.WrappedWebMessageInfo if len(toBackfill) <= req.MaxBatchEvents { msgs = toBackfill toBackfill = toBackfill[0:0] @@ -144,9 +144,14 @@ func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Bac user.log.Debugf("Finished backfilling %d messages in %s", len(allMsgs), portal.Key.JID) if len(insertionEventIds) > 0 { portal.sendPostBackfillDummy( - time.Unix(int64(allMsgs[len(allMsgs)-1].GetMessageTimestamp()), 0), + time.Unix(int64(allMsgs[len(allMsgs)-1].Message.GetMessageTimestamp()), 0), insertionEventIds[0]) } + user.log.Debugf("Deleting %d history sync messages after backfilling", len(allMsgs)) + err := user.bridge.DB.HistorySyncQuery.DeleteMessages(allMsgs) + if err != nil { + user.log.Warnf("Failed to delete %d history sync messages after backfilling: %v", len(allMsgs), err) + } } else { user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID) } @@ -288,14 +293,14 @@ var ( MSC2716Marker = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType} ) -func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo) []id.EventID { +func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMessageInfo) []id.EventID { portal.backfillLock.Lock() defer portal.backfillLock.Unlock() var historyBatch, newBatch mautrix.ReqBatchSend var historyBatchInfos, newBatchInfos []*wrappedInfo - firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].GetMessageTimestamp()), 0) + firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].Message.GetMessageTimestamp()), 0) historyBatch.StateEventsAtStart = make([]*event.Event, 0) newBatch.StateEventsAtStart = make([]*event.Event, 0) @@ -350,7 +355,7 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo) portal.log.Infofln("Processing history sync with %d messages", len(messages)) // The messages are ordered newest to oldest, so iterate them in reverse order. for i := len(messages) - 1; i >= 0; i-- { - webMsg := messages[i] + webMsg := messages[i].Message msgType := getMessageType(webMsg.GetMessage()) if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" { if msgType != "ignore" {