diff --git a/database/historysync.go b/database/historysync.go index 8532c5e..efbc784 100644 --- a/database/historysync.go +++ b/database/historysync.go @@ -20,7 +20,6 @@ import ( "database/sql" "errors" "fmt" - "strconv" "strings" "time" @@ -205,7 +204,7 @@ func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) error { const ( getMessagesBetween = ` - SELECT id, data + SELECT data FROM history_sync_message WHERE user_mxid=$1 AND conversation_id=$2 @@ -215,7 +214,7 @@ const ( ` deleteMessages = ` DELETE FROM history_sync_message - WHERE id IN (%s) + WHERE %s ` ) @@ -223,19 +222,14 @@ type HistorySyncMessage struct { db *Database log log.Logger - ID int UserID id.UserID ConversationID string + MessageID string Timestamp time.Time Data []byte } -type WrappedWebMessageInfo struct { - ID int - Message *waProto.WebMessageInfo -} - -func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) { +func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID, messageID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) { msgData, err := proto.Marshal(message) if err != nil { return nil, err @@ -245,6 +239,7 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation log: hsq.log, UserID: userID, ConversationID: conversationID, + MessageID: messageID, Timestamp: time.Unix(int64(message.Message.GetMessageTimestamp()), 0), Data: msgData, }, nil @@ -252,15 +247,16 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation func (hsm *HistorySyncMessage) Insert() { _, err := hsm.db.Exec(` - INSERT INTO history_sync_message (user_mxid, conversation_id, timestamp, data) - VALUES ($1, $2, $3, $4) - `, hsm.UserID, hsm.ConversationID, hsm.Timestamp, hsm.Data) + INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING + `, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data) if err != nil { hsm.log.Warnfln("Failed to insert history sync message %s/%s: %v", hsm.ConversationID, hsm.Timestamp, err) } } -func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*WrappedWebMessageInfo) { +func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) { whereClauses := "" args := []interface{}{userID, conversationID} argNum := 3 @@ -284,10 +280,10 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID if err != nil || rows == nil { return nil } - var msgID int + var msgData []byte for rows.Next() { - err := rows.Scan(&msgID, &msgData) + err := rows.Scan(&msgData) if err != nil { hsq.log.Error("Database scan failed: %v", err) continue @@ -298,21 +294,20 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID hsq.log.Errorf("Failed to unmarshal history sync message: %v", err) continue } - messages = append(messages, &WrappedWebMessageInfo{ - ID: msgID, - Message: historySyncMsg.Message, - }) + messages = append(messages, historySyncMsg.Message) } return } -func (hsq *HistorySyncQuery) DeleteMessages(messages []*WrappedWebMessageInfo) error { - messageIDs := make([]string, len(messages)) +func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error { + whereClauses := []string{} + preparedStatementArgs := []interface{}{userID, conversationID} for i, msg := range messages { - messageIDs[i] = strconv.Itoa(msg.ID) + whereClauses = append(whereClauses, fmt.Sprintf("(user_mxid=$1 AND conversation_id=$2 AND message_id=$%d)", i+3)) + preparedStatementArgs = append(preparedStatementArgs, msg.GetKey().GetId()) } - _, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(messageIDs, ","))) + _, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(whereClauses, " OR ")), preparedStatementArgs...) return err } diff --git a/database/upgrades/2022-03-18-historysync-store.go b/database/upgrades/2022-03-18-historysync-store.go index 5597afb..15564d0 100644 --- a/database/upgrades/2022-03-18-historysync-store.go +++ b/database/upgrades/2022-03-18-historysync-store.go @@ -24,7 +24,6 @@ func init() { unread_count INTEGER, PRIMARY KEY (user_mxid, conversation_id), - UNIQUE (conversation_id), FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -34,14 +33,15 @@ func init() { } _, err = tx.Exec(` CREATE TABLE history_sync_message ( - id SERIAL PRIMARY KEY, user_mxid TEXT, conversation_id TEXT, + message_id TEXT, timestamp TIMESTAMP, data BYTEA, + PRIMARY KEY (user_mxid, conversation_id, message_id), FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE + FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE ) `) if err != nil { @@ -65,7 +65,6 @@ func init() { unread_count INTEGER, PRIMARY KEY (user_mxid, conversation_id), - UNIQUE (conversation_id), FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE ) @@ -75,14 +74,15 @@ func init() { } _, err = tx.Exec(` CREATE TABLE history_sync_message ( - id INTEGER PRIMARY KEY, user_mxid TEXT, conversation_id TEXT, + message_id TEXT, timestamp DATETIME, data BLOB, + PRIMARY KEY (user_mxid, conversation_id, message_id), FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE + FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE ) `) if err != nil { diff --git a/historysync.go b/historysync.go index 8800a39..8790b9c 100644 --- a/historysync.go +++ b/historysync.go @@ -134,7 +134,7 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill break } - var msgs []*database.WrappedWebMessageInfo + var msgs []*waProto.WebMessageInfo if len(toBackfill) <= req.MaxBatchEvents { msgs = toBackfill toBackfill = toBackfill[0:0] @@ -152,11 +152,11 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill user.log.Debugf("Finished backfilling %d messages in %s", len(allMsgs), portal.Key.JID) if len(insertionEventIds) > 0 { portal.sendPostBackfillDummy( - time.Unix(int64(allMsgs[len(allMsgs)-1].Message.GetMessageTimestamp()), 0), + time.Unix(int64(allMsgs[len(allMsgs)-1].GetMessageTimestamp()), 0), insertionEventIds[0]) } user.log.Debugf("Deleting %d history sync messages after backfilling", len(allMsgs)) - err := user.bridge.DB.HistorySyncQuery.DeleteMessages(allMsgs) + err := user.bridge.DB.HistorySyncQuery.DeleteMessages(user.MXID, conv.ConversationID, allMsgs) if err != nil { user.log.Warnf("Failed to delete %d history sync messages after backfilling: %v", len(allMsgs), err) } @@ -227,7 +227,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History continue } - message, err := user.bridge.DB.HistorySyncQuery.NewMessageWithValues(user.MXID, conv.GetId(), msg) + message, err := user.bridge.DB.HistorySyncQuery.NewMessageWithValues(user.MXID, conv.GetId(), msg.Message.GetKey().GetId(), msg) if err != nil { user.log.Warnf("Failed to save message %s in %s. Error: %+v", msg.Message.Key.Id, conv.GetId(), err) continue @@ -306,11 +306,11 @@ var ( MSC2716Marker = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType} ) -func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMessageInfo) []id.EventID { +func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo) []id.EventID { var historyBatch, newBatch mautrix.ReqBatchSend var historyBatchInfos, newBatchInfos []*wrappedInfo - firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].Message.GetMessageTimestamp()), 0) + firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].GetMessageTimestamp()), 0) historyBatch.StateEventsAtStart = make([]*event.Event, 0) newBatch.StateEventsAtStart = make([]*event.Event, 0) @@ -365,7 +365,7 @@ func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMess portal.log.Infofln("Processing history sync with %d messages", len(messages)) // The messages are ordered newest to oldest, so iterate them in reverse order. for i := len(messages) - 1; i >= 0; i-- { - webMsg := messages[i].Message + webMsg := messages[i] msgType := getMessageType(webMsg.GetMessage()) if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" { if msgType != "ignore" {