diff --git a/commands.go b/commands.go index b959cc7..255d22e 100644 --- a/commands.go +++ b/commands.go @@ -873,7 +873,7 @@ func (handler *CommandHandler) CommandBackfill(ce *CommandEvent) { return } } - backfillMessages := ce.Portal.bridge.DB.BackfillQuery.NewWithValues(ce.User.MXID, database.BackfillImmediate, 0, &ce.Portal.Key, nil, nil, batchSize, -1, batchDelay) + backfillMessages := ce.Portal.bridge.DB.Backfill.NewWithValues(ce.User.MXID, database.BackfillImmediate, 0, &ce.Portal.Key, nil, nil, batchSize, -1, batchDelay) backfillMessages.Insert() ce.User.BackfillQueue.ReCheckQueue <- true diff --git a/database/database.go b/database/database.go index 15c39a9..8fe1150 100644 --- a/database/database.go +++ b/database/database.go @@ -49,9 +49,10 @@ type Database struct { Message *MessageQuery Reaction *ReactionQuery - DisappearingMessage *DisappearingMessageQuery - BackfillQuery *BackfillQuery - HistorySyncQuery *HistorySyncQuery + DisappearingMessage *DisappearingMessageQuery + Backfill *BackfillQuery + HistorySync *HistorySyncQuery + MediaBackfillRequest *MediaBackfillRequestQuery } func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) { @@ -89,14 +90,18 @@ func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) { db: db, log: db.log.Sub("DisappearingMessage"), } - db.BackfillQuery = &BackfillQuery{ + db.Backfill = &BackfillQuery{ db: db, log: db.log.Sub("Backfill"), } - db.HistorySyncQuery = &HistorySyncQuery{ + db.HistorySync = &HistorySyncQuery{ db: db, log: db.log.Sub("HistorySync"), } + db.MediaBackfillRequest = &MediaBackfillRequestQuery{ + db: db, + log: db.log.Sub("MediaBackfillRequest"), + } db.SetMaxOpenConns(cfg.MaxOpenConns) db.SetMaxIdleConns(cfg.MaxIdleConns) diff --git a/database/mediabackfillrequest.go b/database/mediabackfillrequest.go new file mode 100644 index 0000000..59155a9 --- /dev/null +++ b/database/mediabackfillrequest.go @@ -0,0 +1,124 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2022 Tulir Asokan, Sumner Evans +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "database/sql" + "errors" + + _ "github.com/mattn/go-sqlite3" + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" +) + +type MediaBackfillRequestStatus int + +const ( + MediaBackfillRequestStatusNotRequested MediaBackfillRequestStatus = iota + MediaBackfillRequestStatusSuccess + MediaBackfillRequestStatusFailed +) + +type MediaBackfillRequestQuery struct { + db *Database + log log.Logger +} + +type MediaBackfillRequest struct { + db *Database + log log.Logger + + UserID id.UserID + PortalKey *PortalKey + EventID id.EventID + Status MediaBackfillRequestStatus + Error string +} + +func (mbrq *MediaBackfillRequestQuery) newMediaBackfillRequest() *MediaBackfillRequest { + return &MediaBackfillRequest{ + db: mbrq.db, + log: mbrq.log, + PortalKey: &PortalKey{}, + } +} + +func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID) *MediaBackfillRequest { + return &MediaBackfillRequest{ + db: mbrq.db, + log: mbrq.log, + UserID: userID, + PortalKey: portalKey, + EventID: eventID, + } +} + +const ( + getMediaBackfillRequestsForUser = ` + SELECT user_mxid, portal_jid, portal_receiver, event_id, status, error + FROM media_backfill_requests + WHERE user_mxid=$1 + ` +) + +func (mbr *MediaBackfillRequest) Upsert() { + _, err := mbr.db.Exec(` + INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, status, error) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id) + DO UPDATE SET + status=EXCLUDED.status, + error=EXCLUDED.error + `, + mbr.UserID, + mbr.PortalKey.JID.String(), + mbr.PortalKey.Receiver.String(), + mbr.EventID, + mbr.Status, + mbr.Error) + if err != nil { + mbr.log.Warnfln("Failed to insert media backfill request %s/%s/%s: %v", mbr.UserID, mbr.PortalKey.String(), mbr.EventID, err) + } +} + +func (mbr *MediaBackfillRequest) Scan(row Scannable) *MediaBackfillRequest { + err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.Status, &mbr.Error) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + mbr.log.Errorln("Database scan failed:", err) + } + return nil + } + return mbr +} + +func (mbr *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(userID id.UserID) (requests []*MediaBackfillRequest) { + rows, err := mbr.db.Query(getMediaBackfillRequestsForUser, userID) + defer rows.Close() + if err != nil || rows == nil { + return nil + } + for rows.Next() { + requests = append(requests, mbr.newMediaBackfillRequest().Scan(rows)) + } + return +} + +func (mbr *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(userID id.UserID) error { + _, err := mbr.db.Exec("DELETE FROM media_backfill_requests WHERE user_mxid=$1", userID) + return err +} diff --git a/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go b/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go new file mode 100644 index 0000000..5cfa6e4 --- /dev/null +++ b/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go @@ -0,0 +1,25 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[42] = upgrade{"Add table of media to request from the user's phone", func(tx *sql.Tx, ctx context) error { + _, err := tx.Exec(` + CREATE TABLE media_backfill_requests ( + user_mxid TEXT, + portal_jid TEXT, + portal_receiver TEXT, + event_id TEXT, + status INTEGER, + error TEXT, + + PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id), + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE + ) + `) + return err + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index 4366655..43eeaf8 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -40,7 +40,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 42 +const NumberOfUpgrades = 43 var upgrades [NumberOfUpgrades]upgrade diff --git a/historysync.go b/historysync.go index 27fef3f..84a3e61 100644 --- a/historysync.go +++ b/historysync.go @@ -53,7 +53,7 @@ func (user *User) handleHistorySyncsLoop() { reCheckQueue := make(chan bool, 1) // Start the backfill queue. user.BackfillQueue = &BackfillQueue{ - BackfillQuery: user.bridge.DB.BackfillQuery, + BackfillQuery: user.bridge.DB.Backfill, ImmediateBackfillRequests: make(chan *database.Backfill, 1), DeferredBackfillRequests: make(chan *database.Backfill, 1), ReCheckQueue: make(chan bool, 1), @@ -82,7 +82,7 @@ func (user *User) handleHistorySyncsLoop() { func (user *User) handleBackfillRequestsLoop(backfillRequests chan *database.Backfill) { for req := range backfillRequests { user.log.Infofln("Handling backfill request %s", req) - conv := user.bridge.DB.HistorySyncQuery.GetConversation(user.MXID, req.Portal) + conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, req.Portal) if conv == nil { user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String()) continue @@ -133,7 +133,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor user.log.Debugfln("Limiting backfill to end at %v", end) } } - allMsgs := user.bridge.DB.HistorySyncQuery.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, req.TimeEnd, req.MaxTotalEvents) + allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, req.TimeEnd, req.MaxTotalEvents) sendDisappearedNotice := false // If expired messages are on, and a notice has not been sent to this chat @@ -211,7 +211,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor insertionEventIds[0]) } user.log.Debugfln("Deleting %d history sync messages after backfilling (queue ID: %d)", len(allMsgs), req.QueueID) - err := user.bridge.DB.HistorySyncQuery.DeleteMessages(user.MXID, conv.ConversationID, allMsgs) + err := user.bridge.DB.HistorySync.DeleteMessages(user.MXID, conv.ConversationID, allMsgs) if err != nil { user.log.Warnfln("Failed to delete %d history sync messages after backfilling (queue ID: %d): %v", len(allMsgs), req.QueueID, err) } @@ -255,7 +255,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History } portal := user.GetPortalByJID(jid) - historySyncConversation := user.bridge.DB.HistorySyncQuery.NewConversationWithValues( + historySyncConversation := user.bridge.DB.HistorySync.NewConversationWithValues( user.MXID, conv.GetId(), &portal.Key, @@ -291,7 +291,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History continue } - message, err := user.bridge.DB.HistorySyncQuery.NewMessageWithValues(user.MXID, conv.GetId(), wmi.GetKey().GetId(), rawMsg) + message, err := user.bridge.DB.HistorySync.NewMessageWithValues(user.MXID, conv.GetId(), wmi.GetKey().GetId(), rawMsg) if err != nil { user.log.Warnfln("Failed to save message %s in %s. Error: %+v", wmi.GetKey().Id, conv.GetId(), err) continue @@ -308,7 +308,7 @@ func (user *User) handleHistorySync(reCheckQueue chan bool, evt *waProto.History return } - nMostRecent := user.bridge.DB.HistorySyncQuery.GetNMostRecentConversations(user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations) + nMostRecent := user.bridge.DB.HistorySync.GetNMostRecentConversations(user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations) if len(nMostRecent) > 0 { // Find the portals for all of the conversations. portals := []*Portal{} @@ -348,7 +348,7 @@ func getConversationTimestamp(conv *waProto.Conversation) uint64 { func (user *User) EnqueueImmedateBackfills(portals []*Portal) { for priority, portal := range portals { maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents - initialBackfill := user.bridge.DB.BackfillQuery.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, nil, maxMessages, maxMessages, 0) + initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, nil, maxMessages, maxMessages, 0) initialBackfill.Insert() } } @@ -362,7 +362,7 @@ func (user *User) EnqueueDeferredBackfills(portals []*Portal) { startDaysAgo := time.Now().AddDate(0, 0, -backfillStage.StartDaysAgo) startDate = &startDaysAgo } - backfillMessages := user.bridge.DB.BackfillQuery.NewWithValues( + backfillMessages := user.bridge.DB.Backfill.NewWithValues( user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, nil, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay) backfillMessages.Insert() } @@ -375,7 +375,7 @@ func (user *User) EnqueueForwardBackfills(portals []*Portal) { if lastMsg == nil { continue } - backfill := user.bridge.DB.BackfillQuery.NewWithValues( + backfill := user.bridge.DB.Backfill.NewWithValues( user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, nil, -1, -1, 0) backfill.Insert() } diff --git a/portal.go b/portal.go index d219757..689c753 100644 --- a/portal.go +++ b/portal.go @@ -1224,8 +1224,8 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, i // before creating the matrix room if errors.Is(err, whatsmeow.ErrNotInGroup) { user.log.Debugfln("Skipping creating matrix room for %s because the user is not a participant", portal.Key.JID) - user.bridge.DB.BackfillQuery.DeleteAllForPortal(user.MXID, portal.Key) - user.bridge.DB.HistorySyncQuery.DeleteAllMessagesForPortal(user.MXID, portal.Key) + user.bridge.DB.Backfill.DeleteAllForPortal(user.MXID, portal.Key) + user.bridge.DB.HistorySync.DeleteAllMessagesForPortal(user.MXID, portal.Key) return err } else if err != nil { portal.log.Warnfln("Failed to get group info through %s: %v", user.JID, err) diff --git a/user.go b/user.go index db16bdb..3d6e630 100644 --- a/user.go +++ b/user.go @@ -428,9 +428,9 @@ func (user *User) DeleteSession() { } // Delete all of the backfill and history sync data. - user.bridge.DB.BackfillQuery.DeleteAll(user.MXID) - user.bridge.DB.HistorySyncQuery.DeleteAllConversations(user.MXID) - user.bridge.DB.HistorySyncQuery.DeleteAllMessages(user.MXID) + user.bridge.DB.Backfill.DeleteAll(user.MXID) + user.bridge.DB.HistorySync.DeleteAllConversations(user.MXID) + user.bridge.DB.HistorySync.DeleteAllMessages(user.MXID) } func (user *User) IsConnected() bool {