Handle decryption errors from WhatsApp properly

This commit is contained in:
Tulir Asokan 2021-10-27 18:31:33 +03:00
parent ded2fb9799
commit b918b4f261
5 changed files with 166 additions and 61 deletions

View file

@ -40,12 +40,34 @@ func (mq *MessageQuery) New() *Message {
}
}
const (
getAllMessagesQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
WHERE chat_jid=$1 AND chat_receiver=$2
`
getMessageByJIDQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
`
getMessageByMXIDQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
WHERE mxid=$1
`
getLastMessageInChatQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
`
getFirstMessageInChatQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
`
)
func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
@ -53,23 +75,19 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
}
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
}
func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
"FROM message WHERE mxid=$1", mxid)
return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
}
func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
return mq.GetLastInChatBefore(chat, time.Now().Add(60 * time.Second))
return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
}
func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1",
chat.JID, chat.Receiver, maxTimestamp.Unix())
msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
if msg == nil || msg.Timestamp.IsZero() {
// Old db, we don't know what the last message is.
return nil
@ -78,13 +96,10 @@ func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Ti
}
func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+
"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1",
chat.JID, chat.Receiver)
return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
}
func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
row := mq.db.QueryRow(query, args...)
func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
if row == nil {
return nil
}
@ -101,6 +116,8 @@ type Message struct {
Sender types.JID
Timestamp time.Time
Sent bool
DecryptionError bool
}
func (msg *Message) IsFakeMXID() bool {
@ -109,7 +126,7 @@ func (msg *Message) IsFakeMXID() bool {
func (msg *Message) Scan(row Scannable) *Message {
var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent)
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError)
if err != nil {
if err != sql.ErrNoRows {
msg.log.Errorln("Database scan failed:", err)
@ -129,9 +146,9 @@ func (msg *Message) Insert() {
sender = ""
}
_, err := msg.db.Exec(`INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent)
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent)
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError)
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}
@ -145,6 +162,15 @@ func (msg *Message) MarkSent() {
}
}
func (msg *Message) UpdateMXID(mxid id.EventID, stillDecryptionError bool) {
msg.MXID = mxid
msg.DecryptionError = stillDecryptionError
_, err := msg.db.Exec("UPDATE message SET mxid=$4, decryption_error=$5 WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, stillDecryptionError)
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}
}
func (msg *Message) Delete() {
_, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {

View file

@ -0,0 +1,12 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View file

@ -39,7 +39,7 @@ type upgrade struct {
fn upgradeFunc
}
const NumberOfUpgrades = 27
const NumberOfUpgrades = 28
var upgrades [NumberOfUpgrades]upgrade

116
portal.go
View file

@ -42,6 +42,8 @@ import (
"golang.org/x/image/webp"
"google.golang.org/protobuf/proto"
"maunium.net/go/mautrix/format"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types"
@ -161,9 +163,15 @@ const recentlyHandledLength = 100
type PortalMessage struct {
evt *events.Message
undecryptable *events.UndecryptableMessage
source *User
}
type recentlyHandledWrapper struct {
id types.MessageID
err bool
}
type Portal struct {
*database.Portal
@ -174,7 +182,7 @@ type Portal struct {
encryptLock sync.Mutex
backfillLock sync.Mutex
recentlyHandled [recentlyHandledLength]types.MessageID
recentlyHandled [recentlyHandledLength]recentlyHandledWrapper
recentlyHandledLock sync.Mutex
recentlyHandledIndex uint8
@ -185,8 +193,6 @@ type Portal struct {
hasRelaybot *bool
}
const MaxMessageAgeToCreatePortal = 5 * 60 // 5 minutes
func (portal *Portal) syncDoublePuppetDetailsAfterCreate(source *User) {
doublePuppet := portal.bridge.GetPuppetByCustomMXID(source.MXID)
if doublePuppet == nil {
@ -210,13 +216,20 @@ func (portal *Portal) handleMessageLoop() {
}
portal.syncDoublePuppetDetailsAfterCreate(msg.source)
}
//portal.backfillLock.Lock()
if msg.evt != nil {
portal.handleMessage(msg.source, msg.evt)
//portal.backfillLock.Unlock()
} else if msg.undecryptable != nil {
portal.handleUndecryptableMessage(msg.source, msg.undecryptable)
} else {
portal.log.Warnln("Unexpected PortalMessage with no message: %+v", msg)
}
}
}
func (portal *Portal) shouldCreateRoom(msg PortalMessage) bool {
if msg.undecryptable != nil {
return true
}
waMsg := msg.evt.Message
supportedMessages := []interface{}{
waMsg.Conversation,
@ -295,6 +308,30 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User,
}
}
const UndecryptableMessage = "Decrypting message from WhatsApp failed, waiting for sender to re-send... " +
"([learn more](https://faq.whatsapp.com/general/security-and-privacy/seeing-waiting-for-this-message-this-may-take-a-while))"
func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.UndecryptableMessage) {
if len(portal.MXID) == 0 {
portal.log.Warnln("handleUndecryptableMessage called even though portal.MXID is empty")
return
} else if portal.isRecentlyHandled(evt.Info.ID, true) {
portal.log.Debugfln("Not handling %s (undecryptable): message was recently handled", evt.Info.ID)
return
} else if existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, evt.Info.ID); existingMsg != nil {
portal.log.Debugfln("Not handling %s (undecryptable): message is duplicate", evt.Info.ID)
return
}
intent := portal.getMessageIntent(source, &evt.Info)
content := format.RenderMarkdown(UndecryptableMessage, true, false)
content.MsgType = event.MsgNotice
resp, err := portal.sendMessage(intent, event.EventMessage, &content, evt.Info.Timestamp.UnixMilli())
if err != nil {
portal.log.Errorln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err)
}
portal.finishHandling(nil, &evt.Info, resp.EventID, true)
}
func (portal *Portal) handleMessage(source *User, evt *events.Message) {
if len(portal.MXID) == 0 {
portal.log.Warnln("handleMessage called even though portal.MXID is empty")
@ -304,24 +341,35 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
msgType := portal.getMessageType(evt.Message)
if msgType == "ignore" {
return
} else if portal.isRecentlyHandled(msgID) {
} else if portal.isRecentlyHandled(msgID, false) {
portal.log.Debugfln("Not handling %s (%s): message was recently handled", msgID, msgType)
return
} else if portal.isDuplicate(msgID) {
}
existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, msgID)
if existingMsg != nil {
if existingMsg.DecryptionError {
portal.log.Debugfln("Got decryptable version of previously undecryptable message %s (%s)", msgID, msgType)
} else {
portal.log.Debugfln("Not handling %s (%s): message is duplicate", msgID, msgType)
return
}
}
intent := portal.getMessageIntent(source, &evt.Info)
converted := portal.convertMessage(intent, source, &evt.Info, evt.Message)
if converted != nil {
var eventID id.EventID
if existingMsg != nil {
converted.Content.SetEdit(existingMsg.MXID)
}
resp, err := portal.sendMessage(converted.Intent, converted.Type, converted.Content, evt.Info.Timestamp.UnixMilli())
if err != nil {
portal.log.Errorln("Failed to send %s to Matrix: %v", msgID, err)
} else {
eventID = resp.EventID
}
if converted.Caption != nil {
// TODO figure out how to handle captions with undecryptable messages turning decryptable
if converted.Caption != nil && existingMsg == nil {
resp, err = portal.sendMessage(converted.Intent, converted.Type, converted.Content, evt.Info.Timestamp.UnixMilli())
if err != nil {
portal.log.Errorln("Failed to send caption of %s to Matrix: %v", msgID, err)
@ -330,55 +378,65 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
}
}
if len(eventID) != 0 {
portal.finishHandling(&evt.Info, resp.EventID)
portal.finishHandling(existingMsg, &evt.Info, resp.EventID, false)
}
} else if msgType == "revoke" {
portal.HandleMessageRevoke(source, evt.Message.GetProtocolMessage().GetKey())
if existingMsg != nil {
_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
Reason: "The undecryptable message was actually the deletion of another message",
})
existingMsg.UpdateMXID("net.maunium.whatsapp.fake::" + existingMsg.MXID, false)
}
} else {
portal.log.Warnln("Unhandled message:", evt.Info, evt.Message)
if existingMsg != nil {
_, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{
Reason: "The undecryptable message contained an unsupported message type",
})
existingMsg.UpdateMXID("net.maunium.whatsapp.fake::" + existingMsg.MXID, false)
}
return
}
portal.bridge.Metrics.TrackWhatsAppMessage(evt.Info.Timestamp, strings.Split(msgType, " ")[0])
}
func (portal *Portal) isRecentlyHandled(id types.MessageID) bool {
func (portal *Portal) isRecentlyHandled(id types.MessageID, decryptionError bool) bool {
start := portal.recentlyHandledIndex
lookingForMsg := recentlyHandledWrapper{id, decryptionError}
for i := start; i != start; i = (i - 1) % recentlyHandledLength {
if portal.recentlyHandled[i] == id {
if portal.recentlyHandled[i] == lookingForMsg {
return true
}
}
return false
}
func (portal *Portal) isDuplicate(id types.MessageID) bool {
msg := portal.bridge.DB.Message.GetByJID(portal.Key, id)
if msg != nil {
return true
}
return false
}
func init() {
gob.Register(&waProto.Message{})
}
func (portal *Portal) markHandled(info *types.MessageInfo, mxid id.EventID, isSent, recent bool) *database.Message {
msg := portal.bridge.DB.Message.New()
func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent, decryptionError bool) *database.Message {
if msg == nil {
msg = portal.bridge.DB.Message.New()
msg.Chat = portal.Key
msg.JID = info.ID
msg.MXID = mxid
msg.Timestamp = info.Timestamp
msg.Sender = info.Sender
msg.Sent = isSent
msg.DecryptionError = decryptionError
msg.Insert()
} else {
msg.UpdateMXID(mxid, decryptionError)
}
if recent {
portal.recentlyHandledLock.Lock()
index := portal.recentlyHandledIndex
portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength
portal.recentlyHandledLock.Unlock()
portal.recentlyHandled[index] = msg.JID
portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, decryptionError}
}
return msg
}
@ -406,10 +464,14 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo) *app
return puppet.IntentFor(portal)
}
func (portal *Portal) finishHandling(message *types.MessageInfo, mxid id.EventID) {
portal.markHandled(message, mxid, true, true)
func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, decryptionError bool) {
portal.markHandled(existing, message, mxid, true, true, decryptionError)
portal.sendDeliveryReceipt(mxid)
if !decryptionError {
portal.log.Debugln("Handled message", message.ID, "->", mxid)
} else {
portal.log.Debugln("Handled message", message.ID, "->", mxid, "(undecryptable message error notice)")
}
}
func (portal *Portal) kickExtraUsers(participantMap map[types.JID]bool) {
@ -896,13 +958,13 @@ func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*types.MessageI
} else if info, ok := infoMap[types.MessageID(msgID)]; !ok {
portal.log.Warnfln("Didn't find info of message %s (event %s) to register it in the database", msgID, eventID)
} else {
portal.markHandled(info, eventID, true, false)
portal.markHandled(nil, info, eventID, true, false, false)
}
}
} else {
for i := 0; i < len(infos); i++ {
if infos[i] != nil {
portal.markHandled(infos[i], eventIDs[i], true, false)
portal.markHandled(nil, infos[i], eventIDs[i], true, false, false)
}
}
portal.log.Infofln("Successfully sent %d events", len(eventIDs))
@ -2358,7 +2420,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
return
}
info := portal.generateMessageInfo(sender)
dbMsg := portal.markHandled(info, evt.ID, false, true)
dbMsg := portal.markHandled(nil, info, evt.ID, false, true, false)
portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
err := sender.Client.SendMessage(portal.Key.JID, info.ID, msg)
if err != nil {

View file

@ -436,7 +436,10 @@ func (user *User) HandleEvent(event interface{}) {
go user.handleReceipt(v)
case *events.Message:
portal := user.GetPortalByJID(v.Info.Chat)
portal.messages <- PortalMessage{v, user}
portal.messages <- PortalMessage{evt: v, source: user}
case *events.UndecryptableMessage:
portal := user.GetPortalByJID(v.Info.Chat)
portal.messages <- PortalMessage{undecryptable: v, source: user}
case *events.HistorySync:
user.historySyncs <- v
case *events.Mute:
@ -458,6 +461,8 @@ func (user *User) HandleEvent(event interface{}) {
if portal != nil {
go user.updateChatTag(nil, portal, user.bridge.Config.Bridge.PinnedTag, v.Action.GetPinned())
}
case *events.AppState:
// Ignore
default:
user.log.Debugfln("Unknown type of event in HandleEvent: %T", v)
}