forked from MirrorHub/mautrix-whatsapp
Reroute broadcast list messages to correct DM portal. Fixes #411
This commit is contained in:
parent
5e04577081
commit
ca5fcc42ba
7 changed files with 82 additions and 35 deletions
|
@ -43,27 +43,27 @@ func (mq *MessageQuery) New() *Message {
|
|||
|
||||
const (
|
||||
getAllMessagesQuery = `
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
|
||||
WHERE chat_jid=$1 AND chat_receiver=$2
|
||||
`
|
||||
getMessageByJIDQuery = `
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
|
||||
WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
|
||||
`
|
||||
getMessageByMXIDQuery = `
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
|
||||
WHERE mxid=$1
|
||||
`
|
||||
getLastMessageInChatQuery = `
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
|
||||
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
|
||||
`
|
||||
getFirstMessageInChatQuery = `
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
|
||||
WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
|
||||
`
|
||||
getMessagesBetweenQuery = `
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
|
||||
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid FROM message
|
||||
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true ORDER BY timestamp ASC
|
||||
`
|
||||
)
|
||||
|
@ -133,7 +133,8 @@ type Message struct {
|
|||
Timestamp time.Time
|
||||
Sent bool
|
||||
|
||||
DecryptionError bool
|
||||
DecryptionError bool
|
||||
BroadcastListJID types.JID
|
||||
}
|
||||
|
||||
func (msg *Message) IsFakeMXID() bool {
|
||||
|
@ -146,7 +147,7 @@ func (msg *Message) IsFakeJID() bool {
|
|||
|
||||
func (msg *Message) Scan(row Scannable) *Message {
|
||||
var ts int64
|
||||
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError)
|
||||
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError, &msg.BroadcastListJID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
msg.log.Errorln("Database scan failed:", err)
|
||||
|
@ -166,9 +167,9 @@ func (msg *Message) Insert() {
|
|||
sender = ""
|
||||
}
|
||||
_, err := msg.db.Exec(`INSERT INTO message
|
||||
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError)
|
||||
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error, broadcast_list_jid)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError, msg.BroadcastListJID)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error {
|
||||
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`)
|
||||
return err
|
||||
}}
|
||||
}
|
|
@ -39,7 +39,7 @@ type upgrade struct {
|
|||
fn upgradeFunc
|
||||
}
|
||||
|
||||
const NumberOfUpgrades = 32
|
||||
const NumberOfUpgrades = 33
|
||||
|
||||
var upgrades [NumberOfUpgrades]upgrade
|
||||
|
||||
|
|
2
go.mod
2
go.mod
|
@ -8,7 +8,7 @@ require (
|
|||
github.com/mattn/go-sqlite3 v1.14.9
|
||||
github.com/prometheus/client_golang v1.11.0
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058
|
||||
go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164
|
||||
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
|
||||
google.golang.org/protobuf v1.27.1
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
|
||||
|
|
4
go.sum
4
go.sum
|
@ -139,8 +139,8 @@ github.com/tidwall/sjson v1.2.3 h1:5+deguEhHSEjmuICXZ21uSSsXotWMA0orU783+Z7Cp8=
|
|||
github.com/tidwall/sjson v1.2.3/go.mod h1:5WdjKx3AQMvCJ4RG6/2UYT7dLrGvJUV1x4jdTAyGvZs=
|
||||
go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910 h1:9FFhG0OmkuMau5UEaTgiUQ+7cSbtbOQ7hiWKdN8OI3I=
|
||||
go.mau.fi/libsignal v0.0.0-20211109153248-a67163214910/go.mod h1:AufGrvVh+00Nc07Jm4hTquh7yleZyn20tKJI2wCPAKg=
|
||||
go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058 h1:5z1PUeFB4XaTtUzXM2n8nK6c+Uu+Mkzm5JliSTCsFL0=
|
||||
go.mau.fi/whatsmeow v0.0.0-20211221173950-fbdc16e29058/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
|
||||
go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164 h1:uA2QfpClxXnrRzkAy08UXJ5P7Wc/QiQFLKZSVAgXg5w=
|
||||
go.mau.fi/whatsmeow v0.0.0-20211225184405-612b42c0c164/go.mod h1:8jUjOAi3xtGubxcZgG8uSHpAdyQXBRbWAfxkctX/4y4=
|
||||
golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
|
|
42
portal.go
42
portal.go
|
@ -43,21 +43,20 @@ import (
|
|||
"golang.org/x/image/webp"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"maunium.net/go/mautrix/format"
|
||||
|
||||
"go.mau.fi/whatsmeow"
|
||||
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"go.mau.fi/whatsmeow/types/events"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/crypto/attachment"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/whatsmeow"
|
||||
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"go.mau.fi/whatsmeow/types/events"
|
||||
|
||||
"maunium.net/go/mautrix-whatsapp/database"
|
||||
)
|
||||
|
||||
|
@ -452,6 +451,12 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
|
|||
}
|
||||
converted := portal.convertMessage(intent, source, &evt.Info, evt.Message)
|
||||
if converted != nil {
|
||||
if evt.Info.IsIncomingBroadcast() {
|
||||
if converted.Extra == nil {
|
||||
converted.Extra = map[string]interface{}{}
|
||||
}
|
||||
converted.Extra["fi.mau.whatsapp.source_broadcast_list"] = evt.Info.Chat.String()
|
||||
}
|
||||
var eventID id.EventID
|
||||
if existingMsg != nil {
|
||||
converted.Content.SetEdit(existingMsg.MXID)
|
||||
|
@ -522,6 +527,9 @@ func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo
|
|||
msg.Sender = info.Sender
|
||||
msg.Sent = isSent
|
||||
msg.DecryptionError = decryptionError
|
||||
if info.IsIncomingBroadcast() {
|
||||
msg.BroadcastListJID = info.Chat
|
||||
}
|
||||
msg.Insert()
|
||||
} else {
|
||||
msg.UpdateMXID(mxid, decryptionError)
|
||||
|
@ -2288,13 +2296,25 @@ func (portal *Portal) HandleMatrixReadReceipt(sender *User, eventID id.EventID,
|
|||
}
|
||||
groupedMessages := make(map[types.JID][]types.MessageID)
|
||||
for _, msg := range messages {
|
||||
if !msg.IsFakeJID() {
|
||||
groupedMessages[msg.Sender] = append(groupedMessages[msg.Sender], msg.JID)
|
||||
}
|
||||
var key types.JID
|
||||
if msg.IsFakeJID() || msg.Sender.User == sender.JID.User {
|
||||
// Don't send read receipts for own messages or fake messages
|
||||
continue
|
||||
} else if !portal.IsPrivateChat() {
|
||||
key = msg.Sender
|
||||
} else if !msg.BroadcastListJID.IsEmpty() {
|
||||
key = msg.BroadcastListJID
|
||||
} // else: blank key (participant field isn't needed in direct chat read receipts)
|
||||
groupedMessages[key] = append(groupedMessages[key], msg.JID)
|
||||
}
|
||||
portal.log.Debugfln("Sending read receipts by %s: %v", sender.JID, groupedMessages)
|
||||
for messageSender, ids := range groupedMessages {
|
||||
err := sender.Client.MarkRead(ids, receiptTimestamp, portal.Key.JID, messageSender)
|
||||
chatJID := portal.Key.JID
|
||||
if messageSender.Server == types.BroadcastServer {
|
||||
chatJID = messageSender
|
||||
messageSender = portal.Key.JID
|
||||
}
|
||||
err := sender.Client.MarkRead(ids, receiptTimestamp, chatJID, messageSender)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to mark %v as read by %s: %v", ids, sender.JID, err)
|
||||
}
|
||||
|
|
32
user.go
32
user.go
|
@ -28,21 +28,20 @@ import (
|
|||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"go.mau.fi/whatsmeow/appstate"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/pushrules"
|
||||
|
||||
"go.mau.fi/whatsmeow"
|
||||
"go.mau.fi/whatsmeow/appstate"
|
||||
"go.mau.fi/whatsmeow/store"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"go.mau.fi/whatsmeow/types/events"
|
||||
waLog "go.mau.fi/whatsmeow/util/log"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix-whatsapp/database"
|
||||
)
|
||||
|
||||
|
@ -442,7 +441,7 @@ func (user *User) HandleEvent(event interface{}) {
|
|||
case *events.ChatPresence:
|
||||
go user.handleChatPresence(v)
|
||||
case *events.Message:
|
||||
portal := user.GetPortalByJID(v.Info.Chat)
|
||||
portal := user.GetPortalByMessageSource(v.Info.MessageSource)
|
||||
portal.messages <- PortalMessage{evt: v, source: user}
|
||||
case *events.CallOffer:
|
||||
user.handleCallStart(v.CallCreator, v.CallID, "", v.Timestamp)
|
||||
|
@ -470,7 +469,7 @@ func (user *User) HandleEvent(event interface{}) {
|
|||
case *events.CallTerminate, *events.CallRelayLatency, *events.CallAccept, *events.UnknownCallEvent:
|
||||
// ignore
|
||||
case *events.UndecryptableMessage:
|
||||
portal := user.GetPortalByJID(v.Info.Chat)
|
||||
portal := user.GetPortalByMessageSource(v.Info.MessageSource)
|
||||
portal.messages <- PortalMessage{undecryptable: v, source: user}
|
||||
case *events.HistorySync:
|
||||
user.historySyncs <- v
|
||||
|
@ -667,6 +666,21 @@ func (user *User) handleLoggedOut(onConnect bool) {
|
|||
user.sendBridgeState(BridgeState{StateEvent: StateBadCredentials, Error: WANotLoggedIn})
|
||||
}
|
||||
|
||||
func (user *User) GetPortalByMessageSource(ms types.MessageSource) *Portal {
|
||||
jid := ms.Chat
|
||||
if ms.IsIncomingBroadcast() {
|
||||
if ms.IsFromMe {
|
||||
jid = ms.BroadcastListOwner.ToNonAD()
|
||||
} else {
|
||||
jid = ms.Sender.ToNonAD()
|
||||
}
|
||||
if jid.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID))
|
||||
}
|
||||
|
||||
func (user *User) GetPortalByJID(jid types.JID) *Portal {
|
||||
return user.bridge.GetPortalByJID(database.NewPortalKey(jid, user.JID))
|
||||
}
|
||||
|
@ -737,7 +751,7 @@ func (user *User) handleReceipt(receipt *events.Receipt) {
|
|||
if receipt.Type != events.ReceiptTypeRead && receipt.Type != events.ReceiptTypeReadSelf {
|
||||
return
|
||||
}
|
||||
portal := user.GetPortalByJID(receipt.Chat)
|
||||
portal := user.GetPortalByMessageSource(receipt.MessageSource)
|
||||
if portal == nil || len(portal.MXID) == 0 {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue