diff --git a/database/backfillqueue.go b/database/backfillqueue.go new file mode 100644 index 0000000..7f012e3 --- /dev/null +++ b/database/backfillqueue.go @@ -0,0 +1,149 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2021 Tulir Asokan, Sumner Evans +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "database/sql" + "errors" + "time" + + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" +) + +type BackfillType int + +const ( + BackfillImmediate BackfillType = 0 + BackfillDeferred = 1 +) + +type BackfillQuery struct { + db *Database + log log.Logger +} + +func (bq *BackfillQuery) New() *Backfill { + return &Backfill{ + db: bq.db, + log: bq.log, + Portal: &PortalKey{}, + } +} + +func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, timeEnd *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill { + return &Backfill{ + db: bq.db, + log: bq.log, + UserID: userID, + BackfillType: backfillType, + Priority: priority, + Portal: portal, + TimeStart: timeStart, + TimeEnd: timeEnd, + MaxBatchEvents: maxBatchEvents, + MaxTotalEvents: maxTotalEvents, + BatchDelay: batchDelay, + } +} + +const ( + getNextBackfillQuery = ` + SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, time_end, max_batch_events, max_total_events, batch_delay + FROM backfill_queue + WHERE user_mxid=$1 + AND type=$2 + ORDER BY priority, queue_id + LIMIT 1 + ` +) + +/// Returns the next backfill to perform +func (bq *BackfillQuery) GetNext(userID id.UserID, backfillType BackfillType) (backfill *Backfill) { + rows, err := bq.db.Query(getNextBackfillQuery, userID, backfillType) + defer rows.Close() + if err != nil || rows == nil { + bq.log.Error(err) + return + } + if rows.Next() { + backfill = bq.New().Scan(rows) + } + return +} + +func (bq *BackfillQuery) DeleteAll(userID id.UserID) error { + _, err := bq.db.Exec("DELETE FROM backfill_queue WHERE user_mxid=$1", userID) + return err +} + +type Backfill struct { + db *Database + log log.Logger + + // Fields + QueueID int + UserID id.UserID + BackfillType BackfillType + Priority int + Portal *PortalKey + TimeStart *time.Time + TimeEnd *time.Time + MaxBatchEvents int + MaxTotalEvents int + BatchDelay int +} + +func (b *Backfill) Scan(row Scannable) *Backfill { + err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.TimeEnd, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + b.log.Errorln("Database scan failed:", err) + } + return nil + } + return b +} + +func (b *Backfill) Insert() { + rows, err := b.db.Query(` + INSERT INTO backfill_queue + (user_mxid, type, priority, portal_jid, portal_receiver, time_start, time_end, max_batch_events, max_total_events, batch_delay) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING queue_id + `, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.TimeEnd, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay) + defer rows.Close() + if err != nil || !rows.Next() { + b.log.Warnfln("Failed to insert %v/%s with priority %d: %v", b.BackfillType, b.Portal.JID, b.Priority, err) + return + } + err = rows.Scan(&b.QueueID) + if err != nil { + b.log.Warnfln("Failed to insert %s/%s with priority %s: %v", b.BackfillType, b.Portal.JID, b.Priority, err) + } +} + +func (b *Backfill) Delete() { + if b.QueueID == 0 { + b.log.Errorf("Cannot delete backfill without queue_id. Maybe it wasn't actually inserted in the database?") + return + } + _, err := b.db.Exec("DELETE FROM backfill_queue WHERE queue_id=$1", b.QueueID) + if err != nil { + b.log.Warnfln("Failed to delete %s/%s: %v", b.BackfillType, b.Priority, err) + } +} diff --git a/database/database.go b/database/database.go index 7640871..cb64a12 100644 --- a/database/database.go +++ b/database/database.go @@ -46,6 +46,8 @@ type Database struct { Reaction *ReactionQuery DisappearingMessage *DisappearingMessageQuery + BackfillQuery *BackfillQuery + HistorySyncQuery *HistorySyncQuery } func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) { @@ -83,6 +85,14 @@ func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) { db: db, log: db.log.Sub("DisappearingMessage"), } + db.BackfillQuery = &BackfillQuery{ + db: db, + log: db.log.Sub("Backfill"), + } + db.HistorySyncQuery = &HistorySyncQuery{ + db: db, + log: db.log.Sub("HistorySync"), + } db.SetMaxOpenConns(cfg.MaxOpenConns) db.SetMaxIdleConns(cfg.MaxIdleConns) diff --git a/database/historysync.go b/database/historysync.go new file mode 100644 index 0000000..04821a2 --- /dev/null +++ b/database/historysync.go @@ -0,0 +1,293 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2022 Tulir Asokan, Sumner Evans +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "database/sql" + "errors" + "fmt" + "time" + + waProto "go.mau.fi/whatsmeow/binary/proto" + "google.golang.org/protobuf/proto" + + _ "github.com/mattn/go-sqlite3" + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" +) + +type HistorySyncQuery struct { + db *Database + log log.Logger +} + +type HistorySyncConversation struct { + db *Database + log log.Logger + + UserID id.UserID + ConversationID string + PortalKey *PortalKey + LastMessageTimestamp time.Time + MuteEndTime time.Time + Archived bool + Pinned uint32 + DisappearingMode waProto.DisappearingMode_DisappearingModeInitiator + EndOfHistoryTransferType waProto.Conversation_ConversationEndOfHistoryTransferType + EphemeralExpiration *uint32 + MarkedAsUnread bool + UnreadCount uint32 +} + +func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation { + return &HistorySyncConversation{ + db: hsq.db, + log: hsq.log, + PortalKey: &PortalKey{}, + } +} + +func (hsq *HistorySyncQuery) NewConversationWithValues( + userID id.UserID, + conversationID string, + portalKey *PortalKey, + lastMessageTimestamp, + muteEndTime uint64, + archived bool, + pinned uint32, + disappearingMode waProto.DisappearingMode_DisappearingModeInitiator, + endOfHistoryTransferType waProto.Conversation_ConversationEndOfHistoryTransferType, + ephemeralExpiration *uint32, + markedAsUnread bool, + unreadCount uint32) *HistorySyncConversation { + return &HistorySyncConversation{ + db: hsq.db, + log: hsq.log, + UserID: userID, + ConversationID: conversationID, + PortalKey: portalKey, + LastMessageTimestamp: time.Unix(int64(lastMessageTimestamp), 0), + MuteEndTime: time.Unix(int64(muteEndTime), 0), + Archived: archived, + Pinned: pinned, + DisappearingMode: disappearingMode, + EndOfHistoryTransferType: endOfHistoryTransferType, + EphemeralExpiration: ephemeralExpiration, + MarkedAsUnread: markedAsUnread, + UnreadCount: unreadCount, + } +} + +const ( + getNMostRecentConversations = ` + SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count + FROM history_sync_conversation + WHERE user_mxid=$1 + ORDER BY last_message_timestamp DESC + LIMIT $2 + ` + getConversationByPortal = ` + SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count + FROM history_sync_conversation + WHERE user_mxid=$1 + AND portal_jid=$2 + AND portal_receiver=$3 + ` +) + +func (hsc *HistorySyncConversation) Upsert() { + _, err := hsc.db.Exec(` + INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ON CONFLICT (user_mxid, conversation_id) + DO UPDATE SET + portal_jid=EXCLUDED.portal_jid, + portal_receiver=EXCLUDED.portal_receiver, + last_message_timestamp=EXCLUDED.last_message_timestamp, + archived=EXCLUDED.archived, + pinned=EXCLUDED.pinned, + mute_end_time=EXCLUDED.mute_end_time, + disappearing_mode=EXCLUDED.disappearing_mode, + end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type, + ephemeral_expiration=EXCLUDED.ephemeral_expiration, + marked_as_unread=EXCLUDED.marked_as_unread, + unread_count=EXCLUDED.unread_count + `, + hsc.UserID, + hsc.ConversationID, + hsc.PortalKey.JID.String(), + hsc.PortalKey.Receiver.String(), + hsc.LastMessageTimestamp, + hsc.Archived, + hsc.Pinned, + hsc.MuteEndTime, + hsc.DisappearingMode, + hsc.EndOfHistoryTransferType, + hsc.EphemeralExpiration, + hsc.MarkedAsUnread, + hsc.UnreadCount) + if err != nil { + hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err) + } +} + +func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation { + err := row.Scan( + &hsc.UserID, + &hsc.ConversationID, + &hsc.PortalKey.JID, + &hsc.PortalKey.Receiver, + &hsc.LastMessageTimestamp, + &hsc.Archived, + &hsc.Pinned, + &hsc.MuteEndTime, + &hsc.DisappearingMode, + &hsc.EndOfHistoryTransferType, + &hsc.EphemeralExpiration, + &hsc.MarkedAsUnread, + &hsc.UnreadCount) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + hsc.log.Errorln("Database scan failed:", err) + } + return nil + } + return hsc +} + +func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) { + rows, err := hsq.db.Query(getNMostRecentConversations, userID, n) + defer rows.Close() + if err != nil || rows == nil { + return nil + } + for rows.Next() { + conversations = append(conversations, hsq.NewConversation().Scan(rows)) + } + return +} + +func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey *PortalKey) (conversation *HistorySyncConversation) { + rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver) + defer rows.Close() + if err != nil || rows == nil { + return nil + } + if rows.Next() { + conversation = hsq.NewConversation().Scan(rows) + } + return +} + +func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) error { + _, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID) + return err +} + +const ( + getMessagesBetween = ` + SELECT data + FROM history_sync_message + WHERE user_mxid=$1 + AND conversation_id=$2 + %s + ORDER BY timestamp DESC + %s + ` +) + +type HistorySyncMessage struct { + db *Database + log log.Logger + + UserID id.UserID + ConversationID string + Timestamp time.Time + Data []byte +} + +func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversationID string, message *waProto.HistorySyncMsg) (*HistorySyncMessage, error) { + msgData, err := proto.Marshal(message) + if err != nil { + return nil, err + } + return &HistorySyncMessage{ + db: hsq.db, + log: hsq.log, + UserID: userID, + ConversationID: conversationID, + Timestamp: time.Unix(int64(message.Message.GetMessageTimestamp()), 0), + Data: msgData, + }, nil +} + +func (hsm *HistorySyncMessage) Insert() { + _, err := hsm.db.Exec(` + INSERT INTO history_sync_message (user_mxid, conversation_id, timestamp, data) + VALUES ($1, $2, $3, $4) + `, hsm.UserID, hsm.ConversationID, hsm.Timestamp, hsm.Data) + if err != nil { + hsm.log.Warnfln("Failed to insert history sync message %s/%s: %v", hsm.ConversationID, hsm.Timestamp, err) + } +} + +func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) { + whereClauses := "" + args := []interface{}{userID, conversationID} + argNum := 3 + if startTime != nil { + whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum) + args = append(args, startTime) + argNum++ + } + if endTime != nil { + whereClauses += fmt.Sprintf(" AND timestamp <= $%d", argNum) + args = append(args, endTime) + } + + limitClause := "" + if limit > 0 { + limitClause = fmt.Sprintf("LIMIT %d", limit) + } + + rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...) + defer rows.Close() + if err != nil || rows == nil { + return nil + } + var msgData []byte + for rows.Next() { + err := rows.Scan(&msgData) + if err != nil { + hsq.log.Error("Database scan failed: %v", err) + continue + } + var historySyncMsg waProto.HistorySyncMsg + err = proto.Unmarshal(msgData, &historySyncMsg) + if err != nil { + hsq.log.Errorf("Failed to unmarshal history sync message: %v", err) + continue + } + messages = append(messages, historySyncMsg.Message) + } + return +} + +func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) error { + _, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID) + return err +} diff --git a/database/upgrades/2022-03-15-prioritized-backfill.go b/database/upgrades/2022-03-15-prioritized-backfill.go new file mode 100644 index 0000000..96fc9cd --- /dev/null +++ b/database/upgrades/2022-03-15-prioritized-backfill.go @@ -0,0 +1,54 @@ +package upgrades + +import "database/sql" + +func init() { + upgrades[39] = upgrade{"Add backfill queue", func(tx *sql.Tx, ctx context) error { + _, err := tx.Exec(` + CREATE TABLE backfill_queue ( + queue_id INTEGER PRIMARY KEY, + user_mxid TEXT, + type INTEGER NOT NULL, + priority INTEGER NOT NULL, + portal_jid VARCHAR(255), + portal_receiver VARCHAR(255), + time_start TIMESTAMP, + time_end TIMESTAMP, + max_batch_events INTEGER NOT NULL, + max_total_events INTEGER, + batch_delay INTEGER, + + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE + ) + `) + if err != nil { + return err + } + + // The queue_id needs to auto-increment every insertion. For SQLite, + // INTEGER PRIMARY KEY is an alias for the ROWID, so it will + // auto-increment. See https://sqlite.org/lang_createtable.html#rowid + // For Postgres, we have to manually add the sequence. + if ctx.dialect == Postgres { + _, err = tx.Exec(` + CREATE SEQUENCE backfill_queue_queue_id_seq + START WITH 1 + OWNED BY backfill_queue.queue_id + `) + if err != nil { + return err + } + _, err = tx.Exec(` + ALTER TABLE backfill_queue + ALTER COLUMN queue_id + SET DEFAULT nextval('backfill_queue_queue_id_seq'::regclass) + `) + if err != nil { + return err + } + } + + return err + }} +} diff --git a/database/upgrades/2022-03-18-historysync-store.go b/database/upgrades/2022-03-18-historysync-store.go new file mode 100644 index 0000000..bf53ab1 --- /dev/null +++ b/database/upgrades/2022-03-18-historysync-store.go @@ -0,0 +1,93 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[40] = upgrade{"Store history syncs for later backfills", func(tx *sql.Tx, ctx context) error { + if ctx.dialect == Postgres { + _, err := tx.Exec(` + CREATE TABLE history_sync_conversation ( + user_mxid TEXT, + conversation_id TEXT, + portal_jid TEXT, + portal_receiver TEXT, + last_message_timestamp TIMESTAMP, + archived BOOLEAN, + pinned INTEGER, + mute_end_time TIMESTAMP, + disappearing_mode INTEGER, + end_of_history_transfer_type INTEGER, + ephemeral_expiration INTEGER, + marked_as_unread BOOLEAN, + unread_count INTEGER, + + PRIMARY KEY (user_mxid, conversation_id), + UNIQUE (conversation_id), + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE + ) + `) + if err != nil { + return err + } + _, err = tx.Exec(` + CREATE TABLE history_sync_message ( + user_mxid TEXT, + conversation_id TEXT, + timestamp TIMESTAMP, + data BYTEA, + + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE + ) + `) + if err != nil { + return err + } + } else if ctx.dialect == SQLite { + _, err := tx.Exec(` + CREATE TABLE history_sync_conversation ( + user_mxid TEXT, + conversation_id TEXT, + portal_jid TEXT, + portal_receiver TEXT, + last_message_timestamp DATETIME, + archived INTEGER, + pinned INTEGER, + mute_end_time DATETIME, + disappearing_mode INTEGER, + end_of_history_transfer_type INTEGER, + ephemeral_expiration INTEGER, + marked_as_unread INTEGER, + unread_count INTEGER, + + PRIMARY KEY (user_mxid, conversation_id), + UNIQUE (conversation_id), + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE + ) + `) + if err != nil { + return err + } + _, err = tx.Exec(` + CREATE TABLE history_sync_message ( + user_mxid TEXT, + conversation_id TEXT, + timestamp DATETIME, + data BLOB, + + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (conversation_id) REFERENCES history_sync_conversation(conversation_id) ON DELETE CASCADE + ) + `) + if err != nil { + return err + } + } + + return nil + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index f4b1ded..98044a4 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -40,7 +40,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 39 +const NumberOfUpgrades = 41 var upgrades [NumberOfUpgrades]upgrade