historysync: use userID, conversationID, messageID as PK

This commit is contained in:
Sumner Evans 2022-04-05 13:13:20 -06:00
parent 005fbb09f8
commit eb0a13a753
No known key found for this signature in database
GPG key ID: 8904527AB50022FD
3 changed files with 32 additions and 37 deletions

View file

@ -20,7 +20,6 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"time" "time"
@ -205,7 +204,7 @@ func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) error {
const ( const (
getMessagesBetween = ` getMessagesBetween = `
SELECT id, data SELECT data
FROM history_sync_message FROM history_sync_message
WHERE user_mxid=$1 WHERE user_mxid=$1
AND conversation_id=$2 AND conversation_id=$2
@ -215,7 +214,7 @@ const (
` `
deleteMessages = ` deleteMessages = `
DELETE FROM history_sync_message DELETE FROM history_sync_message
WHERE id IN (%s) WHERE %s
` `
) )
@ -223,19 +222,14 @@ type HistorySyncMessage struct {
db *Database db *Database
log log.Logger log log.Logger
ID int
UserID id.UserID UserID id.UserID
ConversationID string ConversationID string
MessageID string
Timestamp time.Time Timestamp time.Time
Data []byte Data []byte
} }
type WrappedWebMessageInfo struct { func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID, messageID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) {
ID int
Message *waProto.WebMessageInfo
}
func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) {
msgData, err := proto.Marshal(message) msgData, err := proto.Marshal(message)
if err != nil { if err != nil {
return nil, err return nil, err
@ -245,6 +239,7 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
log: hsq.log, log: hsq.log,
UserID: userID, UserID: userID,
ConversationID: conversationID, ConversationID: conversationID,
MessageID: messageID,
Timestamp: time.Unix(int64(message.Message.GetMessageTimestamp()), 0), Timestamp: time.Unix(int64(message.Message.GetMessageTimestamp()), 0),
Data: msgData, Data: msgData,
}, nil }, nil
@ -252,15 +247,16 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
func (hsm *HistorySyncMessage) Insert() { func (hsm *HistorySyncMessage) Insert() {
_, err := hsm.db.Exec(` _, err := hsm.db.Exec(`
INSERT INTO history_sync_message (user_mxid, conversation_id, timestamp, data) INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4, $5)
`, hsm.UserID, hsm.ConversationID, hsm.Timestamp, hsm.Data) ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data)
if err != nil { if err != nil {
hsm.log.Warnfln("Failed to insert history sync message %s/%s: %v", hsm.ConversationID, hsm.Timestamp, err) hsm.log.Warnfln("Failed to insert history sync message %s/%s: %v", hsm.ConversationID, hsm.Timestamp, err)
} }
} }
func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*WrappedWebMessageInfo) { func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) {
whereClauses := "" whereClauses := ""
args := []interface{}{userID, conversationID} args := []interface{}{userID, conversationID}
argNum := 3 argNum := 3
@ -284,10 +280,10 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
var msgID int
var msgData []byte var msgData []byte
for rows.Next() { for rows.Next() {
err := rows.Scan(&msgID, &msgData) err := rows.Scan(&msgData)
if err != nil { if err != nil {
hsq.log.Error("Database scan failed: %v", err) hsq.log.Error("Database scan failed: %v", err)
continue continue
@ -298,21 +294,20 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
hsq.log.Errorf("Failed to unmarshal history sync message: %v", err) hsq.log.Errorf("Failed to unmarshal history sync message: %v", err)
continue continue
} }
messages = append(messages, &WrappedWebMessageInfo{ messages = append(messages, historySyncMsg.Message)
ID: msgID,
Message: historySyncMsg.Message,
})
} }
return return
} }
func (hsq *HistorySyncQuery) DeleteMessages(messages []*WrappedWebMessageInfo) error { func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
messageIDs := make([]string, len(messages)) whereClauses := []string{}
preparedStatementArgs := []interface{}{userID, conversationID}
for i, msg := range messages { for i, msg := range messages {
messageIDs[i] = strconv.Itoa(msg.ID) whereClauses = append(whereClauses, fmt.Sprintf("(user_mxid=$1 AND conversation_id=$2 AND message_id=$%d)", i+3))
preparedStatementArgs = append(preparedStatementArgs, msg.GetKey().GetId())
} }
_, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(messageIDs, ","))) _, err := hsq.db.Exec(fmt.Sprintf(deleteMessages, strings.Join(whereClauses, " OR ")), preparedStatementArgs...)
return err return err
} }

View file

@ -24,7 +24,6 @@ func init() {
unread_count INTEGER, unread_count INTEGER,
PRIMARY KEY (user_mxid, conversation_id), PRIMARY KEY (user_mxid, conversation_id),
UNIQUE (conversation_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
) )
@ -34,14 +33,15 @@ func init() {
} }
_, err = tx.Exec(` _, err = tx.Exec(`
CREATE TABLE history_sync_message ( CREATE TABLE history_sync_message (
id SERIAL PRIMARY KEY,
user_mxid TEXT, user_mxid TEXT,
conversation_id TEXT, conversation_id TEXT,
message_id TEXT,
timestamp TIMESTAMP, timestamp TIMESTAMP,
data BYTEA, data BYTEA,
PRIMARY KEY (user_mxid, conversation_id, message_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
) )
`) `)
if err != nil { if err != nil {
@ -65,7 +65,6 @@ func init() {
unread_count INTEGER, unread_count INTEGER,
PRIMARY KEY (user_mxid, conversation_id), PRIMARY KEY (user_mxid, conversation_id),
UNIQUE (conversation_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
) )
@ -75,14 +74,15 @@ func init() {
} }
_, err = tx.Exec(` _, err = tx.Exec(`
CREATE TABLE history_sync_message ( CREATE TABLE history_sync_message (
id INTEGER PRIMARY KEY,
user_mxid TEXT, user_mxid TEXT,
conversation_id TEXT, conversation_id TEXT,
message_id TEXT,
timestamp DATETIME, timestamp DATETIME,
data BLOB, data BLOB,
PRIMARY KEY (user_mxid, conversation_id, message_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
) )
`) `)
if err != nil { if err != nil {

View file

@ -134,7 +134,7 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill
break break
} }
var msgs []*database.WrappedWebMessageInfo var msgs []*waProto.WebMessageInfo
if len(toBackfill) <= req.MaxBatchEvents { if len(toBackfill) <= req.MaxBatchEvents {
msgs = toBackfill msgs = toBackfill
toBackfill = toBackfill[0:0] toBackfill = toBackfill[0:0]
@ -152,11 +152,11 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill
user.log.Debugf("Finished backfilling %d messages in %s", len(allMsgs), portal.Key.JID) user.log.Debugf("Finished backfilling %d messages in %s", len(allMsgs), portal.Key.JID)
if len(insertionEventIds) > 0 { if len(insertionEventIds) > 0 {
portal.sendPostBackfillDummy( portal.sendPostBackfillDummy(
time.Unix(int64(allMsgs[len(allMsgs)-1].Message.GetMessageTimestamp()), 0), time.Unix(int64(allMsgs[len(allMsgs)-1].GetMessageTimestamp()), 0),
insertionEventIds[0]) insertionEventIds[0])
} }
user.log.Debugf("Deleting %d history sync messages after backfilling", len(allMsgs)) user.log.Debugf("Deleting %d history sync messages after backfilling", len(allMsgs))
err := user.bridge.DB.HistorySyncQuery.DeleteMessages(allMsgs) err := user.bridge.DB.HistorySyncQuery.DeleteMessages(user.MXID, conv.ConversationID, allMsgs)
if err != nil { if err != nil {
user.log.Warnf("Failed to delete %d history sync messages after backfilling: %v", len(allMsgs), err) user.log.Warnf("Failed to delete %d history sync messages after backfilling: %v", len(allMsgs), err)
} }
@ -227,7 +227,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History
continue continue
} }
message, err := user.bridge.DB.HistorySyncQuery.NewMessageWithValues(user.MXID, conv.GetId(), msg) message, err := user.bridge.DB.HistorySyncQuery.NewMessageWithValues(user.MXID, conv.GetId(), msg.Message.GetKey().GetId(), msg)
if err != nil { if err != nil {
user.log.Warnf("Failed to save message %s in %s. Error: %+v", msg.Message.Key.Id, conv.GetId(), err) user.log.Warnf("Failed to save message %s in %s. Error: %+v", msg.Message.Key.Id, conv.GetId(), err)
continue continue
@ -306,11 +306,11 @@ var (
MSC2716Marker = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType} MSC2716Marker = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType}
) )
func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMessageInfo) []id.EventID { func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo) []id.EventID {
var historyBatch, newBatch mautrix.ReqBatchSend var historyBatch, newBatch mautrix.ReqBatchSend
var historyBatchInfos, newBatchInfos []*wrappedInfo var historyBatchInfos, newBatchInfos []*wrappedInfo
firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].Message.GetMessageTimestamp()), 0) firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].GetMessageTimestamp()), 0)
historyBatch.StateEventsAtStart = make([]*event.Event, 0) historyBatch.StateEventsAtStart = make([]*event.Event, 0)
newBatch.StateEventsAtStart = make([]*event.Event, 0) newBatch.StateEventsAtStart = make([]*event.Event, 0)
@ -365,7 +365,7 @@ func (portal *Portal) backfill(source *User, messages []*database.WrappedWebMess
portal.log.Infofln("Processing history sync with %d messages", len(messages)) portal.log.Infofln("Processing history sync with %d messages", len(messages))
// The messages are ordered newest to oldest, so iterate them in reverse order. // The messages are ordered newest to oldest, so iterate them in reverse order.
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
webMsg := messages[i].Message webMsg := messages[i]
msgType := getMessageType(webMsg.GetMessage()) msgType := getMessageType(webMsg.GetMessage())
if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" { if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" {
if msgType != "ignore" { if msgType != "ignore" {