Reroute broadcast list messages to correct DM portal. Fixes #411

This commit is contained in:
Tulir Asokan 2021-12-25 20:50:36 +02:00
parent 5e04577081
commit ca5fcc42ba
7 changed files with 82 additions and 35 deletions

View file

@ -43,27 +43,27 @@ func (mq *MessageQuery) New() *Message {
const (
getAllMessagesQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2
`
getMessageByJIDQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
`
getMessageByMXIDQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
WHERE mxid=$1
`
getLastMessageInChatQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
`
getFirstMessageInChatQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
`
getMessagesBetweenQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true ORDER BY timestamp ASC
`
)
@ -133,7 +133,8 @@ type Message struct {
Timestamp time.Time
Sent bool
DecryptionError bool
DecryptionError bool
BroadcastListJID types.JID
}
func (msg *Message) IsFakeMXID() bool {
@ -146,7 +147,7 @@ func (msg *Message) IsFakeJID() bool {
func (msg *Message) Scan(row Scannable) *Message {
var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError)
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError, &msg.BroadcastListJID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
msg.log.Errorln("Database scan failed:", err)
@ -166,9 +167,9 @@ func (msg *Message) Insert() {
sender = ""
}
_, err := msg.db.Exec(`INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError)
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError, msg.BroadcastListJID)
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}

View file

@ -0,0 +1,12 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`)
return err
}}
}

View file

@ -39,7 +39,7 @@ type upgrade struct {
fn upgradeFunc
}
const NumberOfUpgrades = 32
const NumberOfUpgrades = 33
var upgrades [NumberOfUpgrades]upgrade

2
go.mod
View file

@ -8,7 +8,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.9
github.com/prometheus/client_golang v1.11.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058
go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
google.golang.org/protobuf v1.27.1
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b

4
go.sum
View file

@ -139,8 +139,8 @@ github.com/tidwall/sjson v1.2.3 h1:5+deguEhHSEjmuICXZ21uSSsXotWMA0orU783+Z7Cp8=
github.com/tidwall/sjson v1.2.3/go.mod h1:5WdjKx3AQMvCJ4RG6/2UYT7dLrGvJUV1x4jdTAyGvZs=
go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910 h1:9FFhG0OmkuMau5UEaTgiUQ+7cSbtbOQ7hiWKdN8OI3I=
go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910/go.mod h1:AufGrvVh+00Nc07Jm4hTquh7yleZyn20tKJI2wCPAKg=
go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058 h1:5z1PUeFB4XaTtUzXM2n8nK6c+Uu+Mkzm5JliSTCsFL0=
go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164 h1:uA2QfpClxXnrRzkAy08UXJ5P7Wc/QiQFLKZSVAgXg5w=
go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View file

@ -43,21 +43,20 @@ import (
"golang.org/x/image/webp"
"google.golang.org/protobuf/proto"
"maunium.net/go/mautrix/format"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow/types/events"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow/types/events"
"maunium.net/go/mautrix-whatsapp/database"
)
@ -452,6 +451,12 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
}
converted := portal.convertMessage(intent, source, &evt.Info, evt.Message)
if converted != nil {
if evt.Info.IsIncomingBroadcast() {
if converted.Extra == nil {
converted.Extra = map[string]interface{}{}
}
converted.Extra["fi.mau.whatsapp.source_broadcast_list"] = evt.Info.Chat.String()
}
var eventID id.EventID
if existingMsg != nil {
converted.Content.SetEdit(existingMsg.MXID)
@ -522,6 +527,9 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
msg.Sender = info.Sender
msg.Sent = isSent
msg.DecryptionError = decryptionError
if info.IsIncomingBroadcast() {
msg.BroadcastListJID = info.Chat
}
msg.Insert()
} else {
msg.UpdateMXID(mxid, decryptionError)
@ -2288,13 +2296,25 @@ func (portal *Portal) HandleMatrixReadReceipt(sender *User, eventID id.EventID,
}
groupedMessages := make(map[types.JID][]types.MessageID)
for _, msg := range messages {
if !msg.IsFakeJID() {
groupedMessages[msg.Sender] = append(groupedMessages[msg.Sender], msg.JID)
}
var key types.JID
if msg.IsFakeJID() || msg.Sender.User == sender.JID.User {
// Don't send read receipts for own messages or fake messages
continue
} else if !portal.IsPrivateChat() {
key = msg.Sender
} else if !msg.BroadcastListJID.IsEmpty() {
key = msg.BroadcastListJID
} // else: blank key (participant field isn't needed in direct chat read receipts)
groupedMessages[key] = append(groupedMessages[key], msg.JID)
}
portal.log.Debugfln("Sending read receipts by %s: %v", sender.JID, groupedMessages)
for messageSender, ids := range groupedMessages {
err := sender.Client.MarkRead(ids, receiptTimestamp, portal.Key.JID, messageSender)
chatJID := portal.Key.JID
if messageSender.Server == types.BroadcastServer {
chatJID = messageSender
messageSender = portal.Key.JID
}
err := sender.Client.MarkRead(ids, receiptTimestamp, chatJID, messageSender)
if err != nil {
portal.log.Warnfln("Failed to mark %v as read by %s: %v", ids, sender.JID, err)
}

32
user.go
View file

@ -28,21 +28,20 @@ import (
log "maunium.net/go/maulogger/v2"
"go.mau.fi/whatsmeow/appstate"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
"go.mau.fi/whatsmeow"
"go.mau.fi/whatsmeow/appstate"
"go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow/types/events"
waLog "go.mau.fi/whatsmeow/util/log"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database"
)
@ -442,7 +441,7 @@ func (user *User) HandleEvent(event interface{}) {
case *events.ChatPresence:
go user.handleChatPresence(v)
case *events.Message:
portal := user.GetPortalByJID(v.Info.Chat)
portal := user.GetPortalByMessageSource(v.Info.MessageSource)
portal.messages <- PortalMessage{evt: v, source: user}
case *events.CallOffer:
user.handleCallStart(v.CallCreator, v.CallID, "", v.Timestamp)
@ -470,7 +469,7 @@ func (user *User) HandleEvent(event interface{}) {
case *events.CallTerminate, *events.CallRelayLatency, *events.CallAccept, *events.UnknownCallEvent:
// ignore
case *events.UndecryptableMessage:
portal := user.GetPortalByJID(v.Info.Chat)
portal := user.GetPortalByMessageSource(v.Info.MessageSource)
portal.messages <- PortalMessage{undecryptable: v, source: user}
case *events.HistorySync:
user.historySyncs <- v
@ -667,6 +666,21 @@ func (user *User) handleLoggedOut(onConnect bool) {
user.sendBridgeState(BridgeState{StateEvent: StateBadCredentials, Error: WANotLoggedIn})
}
func (user *User) GetPortalByMessageSource(ms types.MessageSource) *Portal {
jid := ms.Chat
if ms.IsIncomingBroadcast() {
if ms.IsFromMe {
jid = ms.BroadcastListOwner.ToNonAD()
} else {
jid = ms.Sender.ToNonAD()
}
if jid.IsEmpty() {
return nil
}
}
return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID))
}
func (user *User) GetPortalByJID(jid types.JID) *Portal {
return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID))
}
@ -737,7 +751,7 @@ func (user *User) handleReceipt(receipt *events.Receipt) {
if receipt.Type != events.ReceiptTypeRead && receipt.Type != events.ReceiptTypeReadSelf {
return
}
portal := user.GetPortalByJID(receipt.Chat)
portal := user.GetPortalByMessageSource(receipt.MessageSource)
if portal == nil || len(portal.MXID) == 0 {
return
}