backfill: perform batch finish in transaction

This commit is contained in:
Sumner Evans 2022-05-12 17:56:40 -06:00
parent c1bf0e6555
commit f3f6d88e55
No known key found for this signature in database
GPG key ID: 8904527AB50022FD
7 changed files with 105 additions and 80 deletions

View file

@ -209,7 +209,7 @@ func (handler *CommandHandler) CommandSetRelay(ce *CommandEvent) {
ce.Reply("Only admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ce.User.MXID
ce.Portal.Update()
ce.Portal.Update(nil)
ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account")
}
}
@ -225,7 +225,7 @@ func (handler *CommandHandler) CommandUnsetRelay(ce *CommandEvent) {
ce.Reply("Only admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ""
ce.Portal.Update()
ce.Portal.Update(nil)
ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
}
}
@ -447,7 +447,7 @@ func (handler *CommandHandler) CommandCreate(ce *CommandEvent) {
portal.Encrypted = true
}
portal.Update()
portal.Update(nil)
portal.UpdateBridgeInfo()
ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID)

View file

@ -178,16 +178,26 @@ func (msg *Message) Scan(row Scannable) *Message {
return msg
}
func (msg *Message) Insert() {
func (msg *Message) Insert(txn *sql.Tx) {
var sender interface{} = msg.Sender
// Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
if msg.Sender.IsEmpty() {
sender = ""
}
_, err := msg.db.Exec(`INSERT INTO message
query := `
INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`
args := []interface{}{
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID,
}
var err error
if txn != nil {
_, err = txn.Exec(query, args...)
} else {
_, err = msg.db.Exec(query, args...)
}
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}
@ -202,11 +212,18 @@ func (msg *Message) MarkSent(ts time.Time) {
}
}
func (msg *Message) UpdateMXID(mxid id.EventID, newType MessageType, newError MessageErrorType) {
func (msg *Message) UpdateMXID(txn *sql.Tx, mxid id.EventID, newType MessageType, newError MessageErrorType) {
msg.MXID = mxid
msg.Type = newType
msg.Error = newError
_, err := msg.db.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6", mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
query := "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
args := []interface{}{mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID}
var err error
if txn != nil {
_, err = txn.Exec(query, args...)
} else {
_, err = msg.db.Exec(query, args...)
}
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}

View file

@ -191,9 +191,21 @@ func (portal *Portal) Insert() {
}
}
func (portal *Portal) Update() {
_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8, relay_user_id=$9, expiration_time=$10 WHERE jid=$11 AND receiver=$12",
portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, portal.Key.JID, portal.Key.Receiver)
func (portal *Portal) Update(txn *sql.Tx) {
query := `
UPDATE portal
SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8, relay_user_id=$9, expiration_time=$10
WHERE jid=$11 AND receiver=$12
`
args := []interface{}{
portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, portal.Key.JID, portal.Key.Receiver,
}
var err error
if txn != nil {
_, err = txn.Exec(query, args...)
} else {
_, err = portal.db.Exec(query, args...)
}
if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
}

View file

@ -17,6 +17,7 @@
package main
import (
"database/sql"
"fmt"
"time"
@ -154,7 +155,7 @@ func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Bac
if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
portal.ExpirationTime = *conv.EphemeralExpiration
portal.Update()
portal.Update(nil)
}
user.backfillInChunks(req, conv, portal)
@ -233,7 +234,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
msg.Timestamp = conv.LastMessageTimestamp
msg.Sent = true
msg.Type = database.MsgFake
msg.Insert()
msg.Insert(nil)
return
}
@ -561,9 +562,24 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
portal.log.Errorln("Error batch sending messages:", err)
return nil
} else {
portal.finishBatch(resp.EventIDs, infos)
portal.NextBatchID = resp.NextBatchID
portal.Update()
txn, err := portal.bridge.DB.Begin()
if err != nil {
portal.log.Errorln("Failed to start transaction to save batch messages:", err)
return nil
}
// Do the following block in the transaction
{
portal.finishBatch(txn, resp.EventIDs, infos)
portal.NextBatchID = resp.NextBatchID
portal.Update(txn)
}
err = txn.Commit()
if err != nil {
portal.log.Errorln("Failed to commit transaction to save batch messages:", err)
return nil
}
if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
go portal.requestMediaRetries(source, resp.EventIDs, infos)
}
@ -654,48 +670,27 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
}, nil
}
func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*wrappedInfo) {
if len(eventIDs) != len(infos) {
portal.log.Errorfln("Length of event IDs (%d) and message infos (%d) doesn't match! Using slow path for mapping event IDs", len(eventIDs), len(infos))
infoMap := make(map[types.MessageID]*wrappedInfo, len(infos))
for _, info := range infos {
infoMap[info.ID] = info
func (portal *Portal) finishBatch(txn *sql.Tx, eventIDs []id.EventID, infos []*wrappedInfo) {
for i, info := range infos {
if info == nil {
continue
}
for _, eventID := range eventIDs {
if evt, err := portal.MainIntent().GetEvent(portal.MXID, eventID); err != nil {
portal.log.Warnfln("Failed to get event %s to register it in the database: %v", eventID, err)
} else if msgID, ok := evt.Content.Raw[backfillIDField].(string); !ok {
portal.log.Warnfln("Event %s doesn't include the WhatsApp message ID", eventID)
} else if info, ok := infoMap[types.MessageID(msgID)]; !ok {
portal.log.Warnfln("Didn't find info of message %s (event %s) to register it in the database", msgID, eventID)
eventID := eventIDs[i]
portal.markHandled(txn, nil, info.MessageInfo, eventID, true, false, info.Type, info.Error)
if info.ExpiresIn > 0 {
if info.ExpirationStart > 0 {
remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds()
portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds))
portal.MarkDisappearing(eventID, uint32(remainingSeconds), true)
} else {
portal.finishBatchEvt(info, eventID)
portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn)
portal.MarkDisappearing(eventID, info.ExpiresIn, false)
}
}
} else {
for i := 0; i < len(infos); i++ {
portal.finishBatchEvt(infos[i], eventIDs[i])
}
portal.log.Infofln("Successfully sent %d events", len(eventIDs))
}
}
func (portal *Portal) finishBatchEvt(info *wrappedInfo, eventID id.EventID) {
if info == nil {
return
}
portal.markHandled(nil, info.MessageInfo, eventID, true, false, info.Type, info.Error)
if info.ExpiresIn > 0 {
if info.ExpirationStart > 0 {
remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds()
portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds))
portal.MarkDisappearing(eventID, uint32(remainingSeconds), true)
} else {
portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn)
portal.MarkDisappearing(eventID, info.ExpiresIn, false)
}
}
portal.log.Infofln("Successfully sent %d events", len(eventIDs))
}
func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEventId id.EventID) {
@ -717,7 +712,7 @@ func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEv
msg.Timestamp = lastTimestamp.Add(1 * time.Second)
msg.Sent = true
msg.Type = database.MsgFake
msg.Insert()
msg.Insert(nil)
}
// endregion

View file

@ -74,7 +74,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) {
if portal != nil && !portal.Encrypted {
mx.log.Debugfln("%s enabled encryption in %s", evt.Sender, evt.RoomID)
portal.Encrypted = true
portal.Update()
portal.Update(nil)
if portal.IsPrivateChat() {
err := mx.as.BotIntent().EnsureJoined(portal.MXID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client})
if err != nil {
@ -211,7 +211,7 @@ func (mx *MatrixHandler) createPrivatePortalFromInvite(roomID id.RoomID, inviter
mx.as.StateStore.SetMembership(roomID, mx.bridge.Bot.UserID, event.MembershipJoin)
portal.Encrypted = true
}
portal.Update()
portal.Update(nil)
portal.UpdateBridgeInfo()
_, _ = intent.SendNotice(roomID, "Private chat portal created")
}

View file

@ -19,6 +19,7 @@ package main
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
@ -482,7 +483,7 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User,
return portal.convertGroupInviteMessage(intent, info, waMsg.GetGroupInviteMessage())
case waMsg.ProtocolMessage != nil && waMsg.ProtocolMessage.GetType() == waProto.ProtocolMessage_EPHEMERAL_SETTING:
portal.ExpirationTime = waMsg.ProtocolMessage.GetEphemeralExpiration()
portal.Update()
portal.Update(nil)
return &ConvertedMessage{
Intent: intent,
Type: event.EventMessage,
@ -498,7 +499,7 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User,
func (portal *Portal) UpdateGroupDisappearingMessages(sender *types.JID, timestamp time.Time, timer uint32) {
portal.ExpirationTime = timer
portal.Update()
portal.Update(nil)
intent := portal.MainIntent()
if sender != nil {
intent = portal.bridge.GetPuppetByJID(sender.ToNonAD()).IntentFor(portal)
@ -676,7 +677,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
Reason: "The undecryptable message was actually the deletion of another message",
})
existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
}
} else {
portal.log.Warnfln("Unhandled message: %+v (%s)", evt.Info, msgType)
@ -684,7 +685,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
Reason: "The undecryptable message contained an unsupported message type",
})
existingMsg.UpdateMXID("net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError)
}
return
}
@ -702,7 +703,7 @@ func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.Messa
return false
}
func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, error database.MessageErrorType) *database.Message {
func (portal *Portal) markHandled(txn *sql.Tx, msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message {
if msg == nil {
msg = portal.bridge.DB.Message.New()
msg.Chat = portal.Key
@ -712,13 +713,13 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
msg.Sender = info.Sender
msg.Sent = isSent
msg.Type = msgType
msg.Error = error
msg.Error = errType
if info.IsIncomingBroadcast() {
msg.BroadcastListJID = info.Chat
}
msg.Insert()
msg.Insert(txn)
} else {
msg.UpdateMXID(mxid, msgType, error)
msg.UpdateMXID(txn, mxid, msgType, errType)
}
if recent {
@ -726,7 +727,7 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
index := portal.recentlyHandledIndex
portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength
portal.recentlyHandledLock.Unlock()
portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, error}
portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, errType}
}
return msg
}
@ -747,13 +748,13 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo) *app
return portal.getMessagePuppet(user, info).IntentFor(portal)
}
func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, error database.MessageErrorType) {
portal.markHandled(existing, message, mxid, true, true, msgType, error)
func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, errType database.MessageErrorType) {
portal.markHandled(nil, existing, message, mxid, true, true, msgType, errType)
portal.sendDeliveryReceipt(mxid)
var suffix string
if error == database.MsgErrDecryptionFailed {
if errType == database.MsgErrDecryptionFailed {
suffix = "(undecryptable message error notice)"
} else if error == database.MsgErrMediaNotFound {
} else if errType == database.MsgErrMediaNotFound {
suffix = "(media not found notice)"
}
portal.log.Debugfln("Handled message %s (%s) -> %s %s", message.ID, msgType, mxid, suffix)
@ -1019,7 +1020,7 @@ func (portal *Portal) UpdateMatrixRoom(user *User, groupInfo *types.GroupInfo) b
update = portal.UpdateAvatar(user, types.EmptyJID, false) || update
}
if update {
portal.Update()
portal.Update(nil)
portal.UpdateBridgeInfo()
}
return true
@ -1311,7 +1312,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
return err
}
portal.MXID = resp.RoomID
portal.Update()
portal.Update(nil)
portal.bridge.portalsLock.Lock()
portal.bridge.portalsByMXID[portal.MXID] = portal
portal.bridge.portalsLock.Unlock()
@ -1329,7 +1330,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
if groupInfo != nil {
if groupInfo.IsEphemeral {
portal.ExpirationTime = groupInfo.DisappearingTimer
portal.Update()
portal.Update(nil)
}
portal.SyncParticipants(user, groupInfo)
if groupInfo.IsAnnounce {
@ -1360,7 +1361,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i
portal.log.Errorln("Failed to send dummy event to mark portal creation:", err)
} else {
portal.FirstEventID = firstEventResp.EventID
portal.Update()
portal.Update(nil)
}
if user.bridge.Config.Bridge.HistorySync.Backfill && backfill {
@ -2358,7 +2359,7 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) {
return
}
portal.log.Debugfln("Successfully edited %s -> %s after retry notification for %s", msg.MXID, resp.EventID, retry.MessageID)
msg.UpdateMXID(resp.EventID, database.MsgNormal, database.MsgNoError)
msg.UpdateMXID(nil, resp.EventID, database.MsgNormal, database.MsgNoError)
}
func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID, mediaKey []byte) (bool, error) {
@ -2835,7 +2836,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
}
portal.MarkDisappearing(evt.ID, portal.ExpirationTime, true)
info := portal.generateMessageInfo(sender)
dbMsg := portal.markHandled(nil, info, evt.ID, false, true, database.MsgNormal, database.MsgNoError)
dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgNormal, database.MsgNoError)
portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
ts, err := sender.Client.SendMessage(portal.Key.JID, info.ID, msg)
if err != nil {
@ -2879,7 +2880,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID)
}
info := portal.generateMessageInfo(sender)
dbMsg := portal.markHandled(nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError)
dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError)
portal.upsertReaction(nil, target.JID, sender.JID, evt.ID, info.ID)
portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID)
ts, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp)
@ -3293,6 +3294,6 @@ func (portal *Portal) HandleMatrixMeta(sender *User, evt *event.Event) {
portal.Avatar = newID
portal.AvatarURL = content.URL
portal.UpdateBridgeInfo()
portal.Update()
portal.Update(nil)
}
}

View file

@ -280,7 +280,7 @@ func (puppet *Puppet) updatePortalAvatar() {
}
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
portal.Update()
portal.Update(nil)
})
}
@ -293,7 +293,7 @@ func (puppet *Puppet) updatePortalName() {
}
}
portal.Name = puppet.Displayname
portal.Update()
portal.Update(nil)
})
}