Merge branch 'split-forward-backfill'

This commit is contained in:
Tulir Asokan 2022-04-29 17:30:09 +03:00
commit b8e7c17d5c
5 changed files with 111 additions and 88 deletions

View file

@ -20,6 +20,7 @@ import (
"time" "time"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
) )
@ -32,19 +33,16 @@ type BackfillQueue struct {
log log.Logger log log.Logger
} }
// Immediate backfills should happen first, then deferred backfills and lastly // RunLoop fetches backfills from the database, prioritizing immediate and forward backfills
// media backfills.
func (bq *BackfillQueue) RunLoop(user *User) { func (bq *BackfillQueue) RunLoop(user *User) {
for { for {
if immediate := bq.BackfillQuery.GetNext(user.MXID, database.BackfillImmediate); immediate != nil { if backfill := bq.BackfillQuery.GetNext(user.MXID); backfill != nil {
bq.ImmediateBackfillRequests <- immediate if backfill.BackfillType == database.BackfillImmediate || backfill.BackfillType == database.BackfillForward {
immediate.MarkDone() bq.ImmediateBackfillRequests <- backfill
} else if backfill := bq.BackfillQuery.GetNext(user.MXID, database.BackfillDeferred); backfill != nil { } else {
bq.DeferredBackfillRequests <- backfill bq.DeferredBackfillRequests <- backfill
}
backfill.MarkDone() backfill.MarkDone()
} else if mediaBackfill := bq.BackfillQuery.GetNext(user.MXID, database.BackfillMedia); mediaBackfill != nil {
bq.DeferredBackfillRequests <- mediaBackfill
mediaBackfill.MarkDone()
} else { } else {
select { select {
case <-bq.ReCheckQueue: case <-bq.ReCheckQueue:

View file

@ -30,14 +30,17 @@ type BackfillType int
const ( const (
BackfillImmediate BackfillType = 0 BackfillImmediate BackfillType = 0
BackfillDeferred = 1 BackfillForward BackfillType = 100
BackfillMedia = 2 BackfillDeferred BackfillType = 200
BackfillMedia BackfillType = 300
) )
func (bt BackfillType) String() string { func (bt BackfillType) String() string {
switch bt { switch bt {
case BackfillImmediate: case BackfillImmediate:
return "IMMEDIATE" return "IMMEDIATE"
case BackfillForward:
return "FORWARD"
case BackfillDeferred: case BackfillDeferred:
return "DEFERRED" return "DEFERRED"
case BackfillMedia: case BackfillMedia:
@ -80,16 +83,15 @@ const (
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, time_end, max_batch_events, max_total_events, batch_delay SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, time_end, max_batch_events, max_total_events, batch_delay
FROM backfill_queue FROM backfill_queue
WHERE user_mxid=$1 WHERE user_mxid=$1
AND type=$2
AND completed_at IS NULL AND completed_at IS NULL
ORDER BY priority, queue_id ORDER BY type, priority, queue_id
LIMIT 1 LIMIT 1
` `
) )
// GetNext returns the next backfill to perform // GetNext returns the next backfill to perform
func (bq *BackfillQuery) GetNext(userID id.UserID, backfillType BackfillType) (backfill *Backfill) { func (bq *BackfillQuery) GetNext(userID id.UserID) (backfill *Backfill) {
rows, err := bq.db.Query(getNextBackfillQuery, userID, backfillType) rows, err := bq.db.Query(getNextBackfillQuery, userID)
defer rows.Close() defer rows.Close()
if err != nil || rows == nil { if err != nil || rows == nil {
bq.log.Error(err) bq.log.Error(err)

View file

@ -0,0 +1,20 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[41] = upgrade{"Update backfill queue tables to be sortable by priority", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
UPDATE backfill_queue
SET type=CASE
WHEN type=1 THEN 200
WHEN type=2 THEN 300
ELSE type
END
WHERE type=1 OR type=2
`)
return err
}}
}

View file

@ -40,7 +40,7 @@ type upgrade struct {
fn upgradeFunc fn upgradeFunc
} }
const NumberOfUpgrades = 41 const NumberOfUpgrades = 42
var upgrades [NumberOfUpgrades]upgrade var upgrades [NumberOfUpgrades]upgrade

View file

@ -128,12 +128,12 @@ func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Bac
portal.Update() portal.Update()
} }
user.createOrUpdatePortalAndBackfillWithLock(req, conv, portal) user.backfillInChunks(req, conv, portal)
} }
} }
} }
func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) { func (user *User) backfillInChunks(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) {
portal.backfillLock.Lock() portal.backfillLock.Lock()
defer portal.backfillLock.Unlock() defer portal.backfillLock.Unlock()
@ -141,6 +141,22 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill
return return
} }
var forwardPrevID id.EventID
if req.BackfillType == database.BackfillForward {
// TODO this overrides the TimeStart set when enqueuing the backfill
// maybe the enqueue should instead include the prev event ID
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
forwardPrevID = lastMessage.MXID
start := lastMessage.Timestamp.Add(1 * time.Second)
req.TimeStart = &start
} else {
firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key)
if firstMessage != nil && (req.TimeEnd == nil || firstMessage.Timestamp.Before(*req.TimeEnd)) {
end := firstMessage.Timestamp.Add(-1 * time.Second)
req.TimeEnd = &end
user.log.Debugfln("Limiting backfill to end at %v", end)
}
}
allMsgs := user.bridge.DB.HistorySyncQuery.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, req.TimeEnd, req.MaxTotalEvents) allMsgs := user.bridge.DB.HistorySyncQuery.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, req.TimeEnd, req.MaxTotalEvents)
if len(allMsgs) == 0 { if len(allMsgs) == 0 {
user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID) user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID)
@ -161,7 +177,7 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill
var insertionEventIds []id.EventID var insertionEventIds []id.EventID
for len(toBackfill) > 0 { for len(toBackfill) > 0 {
var msgs []*waProto.WebMessageInfo var msgs []*waProto.WebMessageInfo
if len(toBackfill) <= req.MaxBatchEvents { if len(toBackfill) <= req.MaxBatchEvents || req.MaxBatchEvents < 0 {
msgs = toBackfill msgs = toBackfill
toBackfill = nil toBackfill = nil
} else { } else {
@ -172,7 +188,10 @@ func (user *User) createOrUpdatePortalAndBackfillWithLock(req *database.Backfill
if len(msgs) > 0 { if len(msgs) > 0 {
time.Sleep(time.Duration(req.BatchDelay) * time.Second) time.Sleep(time.Duration(req.BatchDelay) * time.Second)
user.log.Debugfln("Backfilling %d messages in %s (queue ID: %d)", len(msgs), portal.Key.JID, req.QueueID) user.log.Debugfln("Backfilling %d messages in %s (queue ID: %d)", len(msgs), portal.Key.JID, req.QueueID)
insertionEventIds = append(insertionEventIds, portal.backfill(user, msgs)...) resp := portal.backfill(user, msgs, req.BackfillType == database.BackfillForward, forwardPrevID)
if resp != nil {
insertionEventIds = append(insertionEventIds, resp.BaseInsertionEventID)
}
} }
} }
user.log.Debugfln("Finished backfilling %d messages in %s (queue ID: %d)", len(allMsgs), portal.Key.JID, req.QueueID) user.log.Debugfln("Finished backfilling %d messages in %s (queue ID: %d)", len(allMsgs), portal.Key.JID, req.QueueID)
@ -297,6 +316,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History
// Enqueue immediate backfills for the most recent messages first. // Enqueue immediate backfills for the most recent messages first.
user.EnqueueImmedateBackfills(portals) user.EnqueueImmedateBackfills(portals)
case waProto.HistorySync_FULL, waProto.HistorySync_RECENT: case waProto.HistorySync_FULL, waProto.HistorySync_RECENT:
user.EnqueueForwardBackfills(portals)
// Enqueue deferred backfills as configured. // Enqueue deferred backfills as configured.
user.EnqueueDeferredBackfills(portals) user.EnqueueDeferredBackfills(portals)
user.EnqueueMediaBackfills(portals) user.EnqueueMediaBackfills(portals)
@ -340,6 +360,18 @@ func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
} }
} }
func (user *User) EnqueueForwardBackfills(portals []*Portal) {
for priority, portal := range portals {
lastMsg := user.bridge.DB.Message.GetLastInChat(portal.Key)
if lastMsg == nil {
continue
}
backfill := user.bridge.DB.BackfillQuery.NewWithValues(
user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, nil, -1, -1, 0)
backfill.Insert()
}
}
func (user *User) EnqueueMediaBackfills(portals []*Portal) { func (user *User) EnqueueMediaBackfills(portals []*Portal) {
numPortals := len(portals) numPortals := len(portals)
for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Media { for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Media {
@ -367,16 +399,26 @@ var (
HistorySyncMarker = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType} HistorySyncMarker = event.Type{Type: "org.matrix.msc2716.marker", Class: event.MessageEventType}
) )
func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo) []id.EventID { func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, isForward bool, prevEventID id.EventID) *mautrix.RespBatchSend {
var historyBatch, newBatch mautrix.ReqBatchSend var req mautrix.ReqBatchSend
var historyBatchInfos, newBatchInfos []*wrappedInfo var infos []*wrappedInfo
firstMsgTimestamp := time.Unix(int64(messages[len(messages)-1].GetMessageTimestamp()), 0) if !isForward {
if portal.FirstEventID != "" || portal.NextBatchID != "" {
req.PrevEventID = portal.FirstEventID
req.BatchID = portal.NextBatchID
} else {
portal.log.Warnfln("Can't backfill %d messages through %s to chat: first event ID not known", len(messages), source.MXID)
return nil
}
} else {
req.PrevEventID = prevEventID
}
historyBatch.StateEventsAtStart = make([]*event.Event, 0) beforeFirstMessageTimestampMillis := (int64(messages[len(messages)-1].GetMessageTimestamp()) * 1000) - 1
newBatch.StateEventsAtStart = make([]*event.Event, 0) req.StateEventsAtStart = make([]*event.Event, 0)
addedMembers := make(map[id.UserID]*event.MemberEventContent) addedMembers := make(map[id.UserID]struct{})
addMember := func(puppet *Puppet) { addMember := func(puppet *Puppet) {
if _, alreadyAdded := addedMembers[puppet.MXID]; alreadyAdded { if _, alreadyAdded := addedMembers[puppet.MXID]; alreadyAdded {
return return
@ -389,41 +431,23 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo)
} }
inviteContent := content inviteContent := content
inviteContent.Membership = event.MembershipInvite inviteContent.Membership = event.MembershipInvite
historyBatch.StateEventsAtStart = append(historyBatch.StateEventsAtStart, &event.Event{ req.StateEventsAtStart = append(req.StateEventsAtStart, &event.Event{
Type: event.StateMember, Type: event.StateMember,
Sender: portal.MainIntent().UserID, Sender: portal.MainIntent().UserID,
StateKey: &mxid, StateKey: &mxid,
Timestamp: firstMsgTimestamp.UnixMilli(), Timestamp: beforeFirstMessageTimestampMillis,
Content: event.Content{Parsed: &inviteContent}, Content: event.Content{Parsed: &inviteContent},
}, &event.Event{ }, &event.Event{
Type: event.StateMember, Type: event.StateMember,
Sender: puppet.MXID, Sender: puppet.MXID,
StateKey: &mxid, StateKey: &mxid,
Timestamp: firstMsgTimestamp.UnixMilli(), Timestamp: beforeFirstMessageTimestampMillis,
Content: event.Content{Parsed: &content}, Content: event.Content{Parsed: &content},
}) })
addedMembers[puppet.MXID] = &content addedMembers[puppet.MXID] = struct{}{}
} }
firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key) portal.log.Infofln("Processing history sync with %d messages (forward: %t)", len(messages), isForward)
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
var historyMaxTs, newMinTs time.Time
if portal.FirstEventID != "" || portal.NextBatchID != "" {
historyBatch.PrevEventID = portal.FirstEventID
historyBatch.BatchID = portal.NextBatchID
if firstMessage == nil && lastMessage == nil {
historyMaxTs = time.Now()
} else {
historyMaxTs = firstMessage.Timestamp
}
}
if lastMessage != nil {
newBatch.PrevEventID = lastMessage.MXID
newMinTs = lastMessage.Timestamp
}
portal.log.Debugfln("Processing backfill with %d messages", len(messages))
// The messages are ordered newest to oldest, so iterate them in reverse order. // The messages are ordered newest to oldest, so iterate them in reverse order.
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
webMsg := messages[i] webMsg := messages[i]
@ -446,15 +470,6 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo)
if info == nil { if info == nil {
continue continue
} }
var batch *mautrix.ReqBatchSend
var infos *[]*wrappedInfo
if !historyMaxTs.IsZero() && info.Timestamp.Before(historyMaxTs) {
batch, infos = &historyBatch, &historyBatchInfos
} else if !newMinTs.IsZero() && info.Timestamp.After(newMinTs) {
batch, infos = &newBatch, &newBatchInfos
} else {
continue
}
if webMsg.GetPushName() != "" && webMsg.GetPushName() != "-" { if webMsg.GetPushName() != "" && webMsg.GetPushName() != "-" {
existingContact, _ := source.Client.Store.Contacts.GetContact(info.Sender) existingContact, _ := source.Client.Store.Contacts.GetContact(info.Sender)
if !existingContact.Found || existingContact.PushName == "" { if !existingContact.Found || existingContact.PushName == "" {
@ -484,13 +499,18 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo)
if len(converted.ReplyTo) > 0 { if len(converted.ReplyTo) > 0 {
portal.SetReply(converted.Content, converted.ReplyTo) portal.SetReply(converted.Content, converted.ReplyTo)
} }
err := portal.appendBatchEvents(converted, info, webMsg.GetEphemeralStartTimestamp(), &batch.Events, infos) err := portal.appendBatchEvents(converted, info, webMsg.GetEphemeralStartTimestamp(), &req.Events, &infos)
if err != nil { if err != nil {
portal.log.Errorfln("Error handling message %s during backfill: %v", info.ID, err) portal.log.Errorfln("Error handling message %s during backfill: %v", info.ID, err)
} }
} }
portal.log.Infofln("Made %d Matrix events from messages in batch", len(req.Events))
if (len(historyBatch.Events) > 0 && len(historyBatch.BatchID) == 0) || len(newBatch.Events) > 0 { if len(req.Events) == 0 {
return nil
}
if len(req.BatchID) == 0 || isForward {
portal.log.Debugln("Sending a dummy event to avoid forward extremity errors with backfill") portal.log.Debugln("Sending a dummy event to avoid forward extremity errors with backfill")
_, err := portal.MainIntent().SendMessageEvent(portal.MXID, PreBackfillDummyEvent, struct{}{}) _, err := portal.MainIntent().SendMessageEvent(portal.MXID, PreBackfillDummyEvent, struct{}{})
if err != nil { if err != nil {
@ -498,33 +518,16 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo)
} }
} }
var insertionEventIds []id.EventID resp, err := portal.MainIntent().BatchSend(portal.MXID, &req)
if err != nil {
if len(historyBatch.Events) > 0 && len(historyBatch.PrevEventID) > 0 { portal.log.Errorln("Error batch sending messages:", err)
portal.log.Infofln("Sending %d historical messages...", len(historyBatch.Events)) return nil
historyResp, err := portal.MainIntent().BatchSend(portal.MXID, &historyBatch) } else {
if err != nil { portal.finishBatch(resp.EventIDs, infos)
portal.log.Errorln("Error sending batch of historical messages:", err) portal.NextBatchID = resp.NextBatchID
} else { portal.Update()
insertionEventIds = append(insertionEventIds, historyResp.BaseInsertionEventID) return resp
portal.finishBatch(historyResp.EventIDs, historyBatchInfos)
portal.NextBatchID = historyResp.NextBatchID
portal.Update()
}
} }
if len(newBatch.Events) > 0 && len(newBatch.PrevEventID) > 0 {
portal.log.Infofln("Sending %d new messages...", len(newBatch.Events))
newResp, err := portal.MainIntent().BatchSend(portal.MXID, &newBatch)
if err != nil {
portal.log.Errorln("Error sending batch of new messages:", err)
} else {
insertionEventIds = append(insertionEventIds, newResp.BaseInsertionEventID)
portal.finishBatch(newResp.EventIDs, newBatchInfos)
}
}
return insertionEventIds
} }
func (portal *Portal) parseWebMessageInfo(source *User, webMsg *waProto.WebMessageInfo) *types.MessageInfo { func (portal *Portal) parseWebMessageInfo(source *User, webMsg *waProto.WebMessageInfo) *types.MessageInfo {