diff --git a/database/mediabackfillrequest.go b/database/mediabackfillrequest.go index 45a5245..d8a6cb5 100644 --- a/database/mediabackfillrequest.go +++ b/database/mediabackfillrequest.go @@ -45,6 +45,7 @@ type MediaBackfillRequest struct { UserID id.UserID PortalKey *PortalKey EventID id.EventID + MediaKey []byte Status MediaBackfillRequestStatus Error string } @@ -57,20 +58,21 @@ func (mbrq *MediaBackfillRequestQuery) newMediaBackfillRequest() *MediaBackfillR } } -func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID) *MediaBackfillRequest { +func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest { return &MediaBackfillRequest{ db: mbrq.db, log: mbrq.log, UserID: userID, PortalKey: portalKey, EventID: eventID, + MediaKey: mediaKey, Status: MediaBackfillRequestStatusNotRequested, } } const ( getMediaBackfillRequestsForUser = ` - SELECT user_mxid, portal_jid, portal_receiver, event_id, status, error + SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error FROM media_backfill_requests WHERE user_mxid=$1 AND status=0 @@ -79,17 +81,18 @@ const ( func (mbr *MediaBackfillRequest) Upsert() { _, err := mbr.db.Exec(` - INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, status, error) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id) DO UPDATE SET + media_key=EXCLUDED.media_key, status=EXCLUDED.status, - error=EXCLUDED.error - `, + error=EXCLUDED.error`, mbr.UserID, mbr.PortalKey.JID.String(), mbr.PortalKey.Receiver.String(), mbr.EventID, + mbr.MediaKey, mbr.Status, mbr.Error) if err != nil { @@ -98,7 +101,7 @@ func (mbr *MediaBackfillRequest) Upsert() { } func (mbr *MediaBackfillRequest) Scan(row Scannable) *MediaBackfillRequest { - err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.Status, &mbr.Error) + err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error) if err != nil { if !errors.Is(err, sql.ErrNoRows) { mbr.log.Errorln("Database scan failed:", err) diff --git a/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go b/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go index 5cfa6e4..2470ffa 100644 --- a/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go +++ b/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go @@ -12,6 +12,7 @@ func init() { portal_jid TEXT, portal_receiver TEXT, event_id TEXT, + media_key BYTEA, status INTEGER, error TEXT, diff --git a/historysync.go b/historysync.go index 9bf9266..fed28d8 100644 --- a/historysync.go +++ b/historysync.go @@ -112,7 +112,7 @@ func (user *User) dailyMediaRequestLoop() { // Send all of the media backfill requests for the user at once for _, req := range mediaBackfillRequests { portal := user.GetPortalByJID(req.PortalKey.JID) - _, err := portal.requestMediaRetry(user, req.EventID) + _, err := portal.requestMediaRetry(user, req.EventID, req.MediaKey) if err != nil { user.log.Warnf("Failed to send media retry request for %s / %s", req.PortalKey.String(), req.EventID) req.Status = database.MediaBackfillRequestStatusRequestFailed @@ -121,6 +121,7 @@ func (user *User) dailyMediaRequestLoop() { user.log.Debugfln("Sent media retry request for %s / %s", req.PortalKey.String(), req.EventID) req.Status = database.MediaBackfillRequestStatusRequested } + req.MediaKey = nil req.Upsert() } @@ -588,7 +589,7 @@ func (portal *Portal) requestMediaRetries(source *User, eventIDs []id.EventID, i portal.log.Debugfln("Sent post-backfill media retry request for %s", info.ID) } case config.MediaRequestMethodLocalTime: - req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, &portal.Key, eventIDs[i]) + req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, &portal.Key, eventIDs[i], info.MediaKey) req.Upsert() } } diff --git a/matrix.go b/matrix.go index 90309a1..2b9f899 100644 --- a/matrix.go +++ b/matrix.go @@ -491,7 +491,7 @@ func (mx *MatrixHandler) HandleReaction(evt *event.Event) { content := evt.Content.AsReaction() if strings.Contains(content.RelatesTo.Key, "retry") || strings.HasPrefix(content.RelatesTo.Key, "\u267b") { // ♻️ - if retryRequested, _ := portal.requestMediaRetry(user, content.RelatesTo.EventID); retryRequested { + if retryRequested, _ := portal.requestMediaRetry(user, content.RelatesTo.EventID, nil); retryRequested { _, _ = portal.MainIntent().RedactEvent(portal.MXID, evt.ID, mautrix.ReqRedact{ Reason: "requested media from phone", }) diff --git a/portal.go b/portal.go index b791eae..4759c6b 100644 --- a/portal.go +++ b/portal.go @@ -2346,7 +2346,7 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) { msg.UpdateMXID(resp.EventID, database.MsgNormal, database.MsgNoError) } -func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID) (bool, error) { +func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID, mediaKey []byte) (bool, error) { msg := portal.bridge.DB.Message.GetByMXID(eventID) if msg == nil { err := errors.New(fmt.Sprintf("%s requested a media retry for unknown event %s", user.MXID, eventID)) @@ -2358,13 +2358,17 @@ func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID) (bool, e return false, err } - evt, err := portal.fetchMediaRetryEvent(msg) - if err != nil { - portal.log.Warnfln("Can't send media retry request for %s: %v", msg.JID, err) - return true, nil + // If the media key is not provided, grab it from the event in Matrix + if mediaKey == nil { + evt, err := portal.fetchMediaRetryEvent(msg) + if err != nil { + portal.log.Warnfln("Can't send media retry request for %s: %v", msg.JID, err) + return true, nil + } + mediaKey = evt.Media.Key } - err = user.Client.SendMediaRetryReceipt(&types.MessageInfo{ + err := user.Client.SendMediaRetryReceipt(&types.MessageInfo{ ID: msg.JID, MessageSource: types.MessageSource{ IsFromMe: msg.Sender.User == user.JID.User, @@ -2372,7 +2376,7 @@ func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID) (bool, e Sender: msg.Sender, Chat: portal.Key.JID, }, - }, evt.Media.Key) + }, mediaKey) if err != nil { portal.log.Warnfln("Failed to send media retry request for %s: %v", msg.JID, err) } else {