diff --git a/database/message.go b/database/message.go index b6ef12b..7f29968 100644 --- a/database/message.go +++ b/database/message.go @@ -43,27 +43,27 @@ func (mq *MessageQuery) New() *Message { const ( getAllMessagesQuery = ` - SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message WHERE chat_jid=$1 AND chat_receiver=$2 ` getMessageByJIDQuery = ` - SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3 ` getMessageByMXIDQuery = ` - SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message WHERE mxid=$1 ` getLastMessageInChatQuery = ` - SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1 ` getFirstMessageInChatQuery = ` - SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1 ` getMessagesBetweenQuery = ` - SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true ORDER BY timestamp ASC ` ) @@ -133,7 +133,8 @@ type Message struct { Timestamp time.Time Sent bool - DecryptionError bool + DecryptionError bool + BroadcastListJID types.JID } func (msg *Message) IsFakeMXID() bool { @@ -146,7 +147,7 @@ func (msg *Message) IsFakeJID() bool { func (msg *Message) Scan(row Scannable) *Message { var ts int64 - err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError) + err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError, &msg.BroadcastListJID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { msg.log.Errorln("Database scan failed:", err) @@ -166,9 +167,9 @@ func (msg *Message) Insert() { sender = "" } _, err := msg.db.Exec(`INSERT INTO message - (chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, - msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError) + (chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, + msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError, msg.BroadcastListJID) if err != nil { msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err) } diff --git a/database/upgrades/2021-12-25-broadcast-list-message-source.go b/database/upgrades/2021-12-25-broadcast-list-message-source.go new file mode 100644 index 0000000..11059de --- /dev/null +++ b/database/upgrades/2021-12-25-broadcast-list-message-source.go @@ -0,0 +1,12 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error { + _, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`) + return err + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index 5912994..293126c 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -39,7 +39,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 32 +const NumberOfUpgrades = 33 var upgrades [NumberOfUpgrades]upgrade diff --git a/go.mod b/go.mod index d96087c..ecfaac0 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.9 github.com/prometheus/client_golang v1.11.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058 + go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164 golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d google.golang.org/protobuf v1.27.1 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b diff --git a/go.sum b/go.sum index 67ca498..e1f7778 100644 --- a/go.sum +++ b/go.sum @@ -139,8 +139,8 @@ github.com/tidwall/sjson v1.2.3 h1:5+deguEhHSEjmuICXZ21uSSsXotWMA0orU783+Z7Cp8= github.com/tidwall/sjson v1.2.3/go.mod h1:5WdjKx3AQMvCJ4RG6/2UYT7dLrGvJUV1x4jdTAyGvZs= go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910 h1:9FFhG0OmkuMau5UEaTgiUQ+7cSbtbOQ7hiWKdN8OI3I= go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910/go.mod h1:AufGrvVh+00Nc07Jm4hTquh7yleZyn20tKJI2wCPAKg= -go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058 h1:5z1PUeFB4XaTtUzXM2n8nK6c+Uu+Mkzm5JliSTCsFL0= -go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4= +go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164 h1:uA2QfpClxXnrRzkAy08UXJ5P7Wc/QiQFLKZSVAgXg5w= +go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/portal.go b/portal.go index 1850ffb..8567065 100644 --- a/portal.go +++ b/portal.go @@ -43,21 +43,20 @@ import ( "golang.org/x/image/webp" "google.golang.org/protobuf/proto" - "maunium.net/go/mautrix/format" - - "go.mau.fi/whatsmeow" - waProto "go.mau.fi/whatsmeow/binary/proto" - "go.mau.fi/whatsmeow/types" - "go.mau.fi/whatsmeow/types/events" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" + "go.mau.fi/whatsmeow" + waProto "go.mau.fi/whatsmeow/binary/proto" + "go.mau.fi/whatsmeow/types" + "go.mau.fi/whatsmeow/types/events" + "maunium.net/go/mautrix-whatsapp/database" ) @@ -452,6 +451,12 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) { } converted := portal.convertMessage(intent, source, &evt.Info, evt.Message) if converted != nil { + if evt.Info.IsIncomingBroadcast() { + if converted.Extra == nil { + converted.Extra = map[string]interface{}{} + } + converted.Extra["fi.mau.whatsapp.source_broadcast_list"] = evt.Info.Chat.String() + } var eventID id.EventID if existingMsg != nil { converted.Content.SetEdit(existingMsg.MXID) @@ -522,6 +527,9 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo msg.Sender = info.Sender msg.Sent = isSent msg.DecryptionError = decryptionError + if info.IsIncomingBroadcast() { + msg.BroadcastListJID = info.Chat + } msg.Insert() } else { msg.UpdateMXID(mxid, decryptionError) @@ -2288,13 +2296,25 @@ func (portal *Portal) HandleMatrixReadReceipt(sender *User, eventID id.EventID, } groupedMessages := make(map[types.JID][]types.MessageID) for _, msg := range messages { - if !msg.IsFakeJID() { - groupedMessages[msg.Sender] = append(groupedMessages[msg.Sender], msg.JID) - } + var key types.JID + if msg.IsFakeJID() || msg.Sender.User == sender.JID.User { + // Don't send read receipts for own messages or fake messages + continue + } else if !portal.IsPrivateChat() { + key = msg.Sender + } else if !msg.BroadcastListJID.IsEmpty() { + key = msg.BroadcastListJID + } // else: blank key (participant field isn't needed in direct chat read receipts) + groupedMessages[key] = append(groupedMessages[key], msg.JID) } portal.log.Debugfln("Sending read receipts by %s: %v", sender.JID, groupedMessages) for messageSender, ids := range groupedMessages { - err := sender.Client.MarkRead(ids, receiptTimestamp, portal.Key.JID, messageSender) + chatJID := portal.Key.JID + if messageSender.Server == types.BroadcastServer { + chatJID = messageSender + messageSender = portal.Key.JID + } + err := sender.Client.MarkRead(ids, receiptTimestamp, chatJID, messageSender) if err != nil { portal.log.Warnfln("Failed to mark %v as read by %s: %v", ids, sender.JID, err) } diff --git a/user.go b/user.go index 7e067e7..e610c76 100644 --- a/user.go +++ b/user.go @@ -28,21 +28,20 @@ import ( log "maunium.net/go/maulogger/v2" - "go.mau.fi/whatsmeow/appstate" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" "go.mau.fi/whatsmeow" + "go.mau.fi/whatsmeow/appstate" "go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types/events" waLog "go.mau.fi/whatsmeow/util/log" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix-whatsapp/database" ) @@ -442,7 +441,7 @@ func (user *User) HandleEvent(event interface{}) { case *events.ChatPresence: go user.handleChatPresence(v) case *events.Message: - portal := user.GetPortalByJID(v.Info.Chat) + portal := user.GetPortalByMessageSource(v.Info.MessageSource) portal.messages <- PortalMessage{evt: v, source: user} case *events.CallOffer: user.handleCallStart(v.CallCreator, v.CallID, "", v.Timestamp) @@ -470,7 +469,7 @@ func (user *User) HandleEvent(event interface{}) { case *events.CallTerminate, *events.CallRelayLatency, *events.CallAccept, *events.UnknownCallEvent: // ignore case *events.UndecryptableMessage: - portal := user.GetPortalByJID(v.Info.Chat) + portal := user.GetPortalByMessageSource(v.Info.MessageSource) portal.messages <- PortalMessage{undecryptable: v, source: user} case *events.HistorySync: user.historySyncs <- v @@ -667,6 +666,21 @@ func (user *User) handleLoggedOut(onConnect bool) { user.sendBridgeState(BridgeState{StateEvent: StateBadCredentials, Error: WANotLoggedIn}) } +func (user *User) GetPortalByMessageSource(ms types.MessageSource) *Portal { + jid := ms.Chat + if ms.IsIncomingBroadcast() { + if ms.IsFromMe { + jid = ms.BroadcastListOwner.ToNonAD() + } else { + jid = ms.Sender.ToNonAD() + } + if jid.IsEmpty() { + return nil + } + } + return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID)) +} + func (user *User) GetPortalByJID(jid types.JID) *Portal { return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID)) } @@ -737,7 +751,7 @@ func (user *User) handleReceipt(receipt *events.Receipt) { if receipt.Type != events.ReceiptTypeRead && receipt.Type != events.ReceiptTypeReadSelf { return } - portal := user.GetPortalByJID(receipt.Chat) + portal := user.GetPortalByMessageSource(receipt.MessageSource) if portal == nil || len(portal.MXID) == 0 { return }