From 4bfd3bd644c37b33b72d79dd2598435386ee82e4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Nov 2022 23:09:46 +0200 Subject: [PATCH] Fix marking messages as disappearing while backfilling on SQLite --- database/disappearingmessage.go | 7 +++++-- disappear.go | 5 +++-- historysync.go | 4 ++-- portal.go | 10 +++++----- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/database/disappearingmessage.go b/database/disappearingmessage.go index 90769c1..003c792 100644 --- a/database/disappearingmessage.go +++ b/database/disappearingmessage.go @@ -112,13 +112,16 @@ func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage return msg } -func (msg *DisappearingMessage) Insert() { +func (msg *DisappearingMessage) Insert(txn dbutil.Execable) { + if txn == nil { + txn = msg.db + } var expireAt sql.NullInt64 if !msg.ExpireAt.IsZero() { expireAt.Valid = true expireAt.Int64 = msg.ExpireAt.UnixMilli() } - _, err := msg.db.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`, + _, err := txn.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`, msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), expireAt) if err != nil { msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err) diff --git a/disappear.go b/disappear.go index 34736a3..4a41a8a 100644 --- a/disappear.go +++ b/disappear.go @@ -22,17 +22,18 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix-whatsapp/database" ) -func (portal *Portal) MarkDisappearing(eventID id.EventID, expiresIn uint32, startNow bool) { +func (portal *Portal) MarkDisappearing(txn dbutil.Execable, eventID id.EventID, expiresIn uint32, startNow bool) { if expiresIn == 0 || (!portal.bridge.Config.Bridge.DisappearingMessagesInGroups && portal.IsGroupChat()) { return } msg := portal.bridge.DB.DisappearingMessage.NewWithValues(portal.MXID, eventID, time.Duration(expiresIn)*time.Second, startNow) - msg.Insert() + msg.Insert(txn) if startNow { go portal.sleepAndDelete(msg) } diff --git a/historysync.go b/historysync.go index 6845c29..b0edc39 100644 --- a/historysync.go +++ b/historysync.go @@ -788,10 +788,10 @@ func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID, if info.ExpirationStart > 0 { remainingSeconds := time.Unix(int64(info.ExpirationStart), 0).Add(time.Duration(info.ExpiresIn) * time.Second).Sub(time.Now()).Seconds() portal.log.Debugfln("Disappearing history sync message: expires in %d, started at %d, remaining %d", info.ExpiresIn, info.ExpirationStart, int(remainingSeconds)) - portal.MarkDisappearing(eventID, uint32(remainingSeconds), true) + portal.MarkDisappearing(txn, eventID, uint32(remainingSeconds), true) } else { portal.log.Debugfln("Disappearing history sync message: expires in %d (not started)", info.ExpiresIn) - portal.MarkDisappearing(eventID, info.ExpiresIn, false) + portal.MarkDisappearing(txn, eventID, info.ExpiresIn, false) } } } diff --git a/portal.go b/portal.go index 0bcebaf..2c64cf4 100644 --- a/portal.go +++ b/portal.go @@ -743,7 +743,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) { var eventID id.EventID var lastEventID id.EventID if existingMsg != nil { - portal.MarkDisappearing(existingMsg.MXID, converted.ExpiresIn, false) + portal.MarkDisappearing(nil, existingMsg.MXID, converted.ExpiresIn, false) converted.Content.SetEdit(existingMsg.MXID) } else if converted.ReplyTo != nil { portal.SetReply(converted.Content, converted.ReplyTo, false) @@ -758,7 +758,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) { portal.log.Errorfln("Failed to send %s to Matrix: %v", msgID, err) } else { if editTargetMsg == nil { - portal.MarkDisappearing(resp.EventID, converted.ExpiresIn, false) + portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, false) } eventID = resp.EventID lastEventID = eventID @@ -769,7 +769,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) { if err != nil { portal.log.Errorfln("Failed to send caption of %s to Matrix: %v", msgID, err) } else { - portal.MarkDisappearing(resp.EventID, converted.ExpiresIn, false) + portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, false) lastEventID = resp.EventID } } @@ -779,7 +779,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) { if err != nil { portal.log.Errorfln("Failed to send sub-event %d of %s to Matrix: %v", index+1, msgID, err) } else { - portal.MarkDisappearing(resp.EventID, converted.ExpiresIn, false) + portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, false) lastEventID = resp.EventID } } @@ -3502,7 +3502,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing } dbMsgType := database.MsgNormal if msg.EditedMessage == nil { - portal.MarkDisappearing(origEvtID, portal.ExpirationTime, true) + portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true) } else { dbMsgType = database.MsgEdit }