diff --git a/database/message.go b/database/message.go index 53488fd..2ec7e95 100644 --- a/database/message.go +++ b/database/message.go @@ -40,12 +40,34 @@ func (mq *MessageQuery) New() *Message { } } +const ( + getAllMessagesQuery = ` + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + WHERE chat_jid=$1 AND chat_receiver=$2 + ` + getMessageByJIDQuery = ` + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3 + ` + getMessageByMXIDQuery = ` + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + WHERE mxid=$1 + ` + getLastMessageInChatQuery = ` + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1 + ` + getFirstMessageInChatQuery = ` + SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error FROM message + WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1 + ` +) + func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { - rows, err := mq.db.Query("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver) + rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver) if err != nil || rows == nil { return nil } - defer rows.Close() for rows.Next() { messages = append(messages, mq.New().Scan(rows)) } @@ -53,23 +75,19 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { } func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message { - return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+ - "FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid) + return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid)) } func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message { - return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+ - "FROM message WHERE mxid=$1", mxid) + return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid)) } func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message { - return mq.GetLastInChatBefore(chat, time.Now().Add(60 * time.Second)) + return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second)) } func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message { - msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+ - "FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1", - chat.JID, chat.Receiver, maxTimestamp.Unix()) + msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix())) if msg == nil || msg.Timestamp.IsZero() { // Old db, we don't know what the last message is. return nil @@ -78,13 +96,10 @@ func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Ti } func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message { - return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent "+ - "FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1", - chat.JID, chat.Receiver) + return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver)) } -func (mq *MessageQuery) get(query string, args ...interface{}) *Message { - row := mq.db.QueryRow(query, args...) +func (mq *MessageQuery) maybeScan(row *sql.Row) *Message { if row == nil { return nil } @@ -101,6 +116,8 @@ type Message struct { Sender types.JID Timestamp time.Time Sent bool + + DecryptionError bool } func (msg *Message) IsFakeMXID() bool { @@ -109,7 +126,7 @@ func (msg *Message) IsFakeMXID() bool { func (msg *Message) Scan(row Scannable) *Message { var ts int64 - err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent) + err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.DecryptionError) if err != nil { if err != sql.ErrNoRows { msg.log.Errorln("Database scan failed:", err) @@ -129,9 +146,9 @@ func (msg *Message) Insert() { sender = "" } _, err := msg.db.Exec(`INSERT INTO message - (chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent) - VALUES ($1, $2, $3, $4, $5, $6, $7)`, - msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent) + (chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, decryption_error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.DecryptionError) if err != nil { msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err) } @@ -145,6 +162,15 @@ func (msg *Message) MarkSent() { } } +func (msg *Message) UpdateMXID(mxid id.EventID, stillDecryptionError bool) { + msg.MXID = mxid + msg.DecryptionError = stillDecryptionError + _, err := msg.db.Exec("UPDATE message SET mxid=$4, decryption_error=$5 WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, stillDecryptionError) + if err != nil { + msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err) + } +} + func (msg *Message) Delete() { _, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID) if err != nil { diff --git a/database/upgrades/2021-10-27-message-decryption-errors.go b/database/upgrades/2021-10-27-message-decryption-errors.go new file mode 100644 index 0000000..288709e --- /dev/null +++ b/database/upgrades/2021-10-27-message-decryption-errors.go @@ -0,0 +1,12 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error { + _, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`) + return err + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index e28a5b7..13949c6 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -39,7 +39,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 27 +const NumberOfUpgrades = 28 var upgrades [NumberOfUpgrades]upgrade diff --git a/portal.go b/portal.go index 93ce210..8d2ea1f 100644 --- a/portal.go +++ b/portal.go @@ -42,6 +42,8 @@ import ( "golang.org/x/image/webp" "google.golang.org/protobuf/proto" + "maunium.net/go/mautrix/format" + "go.mau.fi/whatsmeow" waProto "go.mau.fi/whatsmeow/binary/proto" "go.mau.fi/whatsmeow/types" @@ -160,8 +162,14 @@ func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal { const recentlyHandledLength = 100 type PortalMessage struct { - evt *events.Message - source *User + evt *events.Message + undecryptable *events.UndecryptableMessage + source *User +} + +type recentlyHandledWrapper struct { + id types.MessageID + err bool } type Portal struct { @@ -174,7 +182,7 @@ type Portal struct { encryptLock sync.Mutex backfillLock sync.Mutex - recentlyHandled [recentlyHandledLength]types.MessageID + recentlyHandled [recentlyHandledLength]recentlyHandledWrapper recentlyHandledLock sync.Mutex recentlyHandledIndex uint8 @@ -185,8 +193,6 @@ type Portal struct { hasRelaybot *bool } -const MaxMessageAgeToCreatePortal = 5 * 60 // 5 minutes - func (portal *Portal) syncDoublePuppetDetailsAfterCreate(source *User) { doublePuppet := portal.bridge.GetPuppetByCustomMXID(source.MXID) if doublePuppet == nil { @@ -210,13 +216,20 @@ func (portal *Portal) handleMessageLoop() { } portal.syncDoublePuppetDetailsAfterCreate(msg.source) } - //portal.backfillLock.Lock() - portal.handleMessage(msg.source, msg.evt) - //portal.backfillLock.Unlock() + if msg.evt != nil { + portal.handleMessage(msg.source, msg.evt) + } else if msg.undecryptable != nil { + portal.handleUndecryptableMessage(msg.source, msg.undecryptable) + } else { + portal.log.Warnln("Unexpected PortalMessage with no message: %+v", msg) + } } } func (portal *Portal) shouldCreateRoom(msg PortalMessage) bool { + if msg.undecryptable != nil { + return true + } waMsg := msg.evt.Message supportedMessages := []interface{}{ waMsg.Conversation, @@ -295,6 +308,30 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User, } } +const UndecryptableMessage = "Decrypting message from WhatsApp failed, waiting for sender to re-send... " + + "([learn more](https://faq.whatsapp.com/general/security-and-privacy/seeing-waiting-for-this-message-this-may-take-a-while))" + +func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.UndecryptableMessage) { + if len(portal.MXID) == 0 { + portal.log.Warnln("handleUndecryptableMessage called even though portal.MXID is empty") + return + } else if portal.isRecentlyHandled(evt.Info.ID, true) { + portal.log.Debugfln("Not handling %s (undecryptable): message was recently handled", evt.Info.ID) + return + } else if existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, evt.Info.ID); existingMsg != nil { + portal.log.Debugfln("Not handling %s (undecryptable): message is duplicate", evt.Info.ID) + return + } + intent := portal.getMessageIntent(source, &evt.Info) + content := format.RenderMarkdown(UndecryptableMessage, true, false) + content.MsgType = event.MsgNotice + resp, err := portal.sendMessage(intent, event.EventMessage, &content, evt.Info.Timestamp.UnixMilli()) + if err != nil { + portal.log.Errorln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err) + } + portal.finishHandling(nil, &evt.Info, resp.EventID, true) +} + func (portal *Portal) handleMessage(source *User, evt *events.Message) { if len(portal.MXID) == 0 { portal.log.Warnln("handleMessage called even though portal.MXID is empty") @@ -304,24 +341,35 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) { msgType := portal.getMessageType(evt.Message) if msgType == "ignore" { return - } else if portal.isRecentlyHandled(msgID) { + } else if portal.isRecentlyHandled(msgID, false) { portal.log.Debugfln("Not handling %s (%s): message was recently handled", msgID, msgType) return - } else if portal.isDuplicate(msgID) { - portal.log.Debugfln("Not handling %s (%s): message is duplicate", msgID, msgType) - return } + existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, msgID) + if existingMsg != nil { + if existingMsg.DecryptionError { + portal.log.Debugfln("Got decryptable version of previously undecryptable message %s (%s)", msgID, msgType) + } else { + portal.log.Debugfln("Not handling %s (%s): message is duplicate", msgID, msgType) + return + } + } + intent := portal.getMessageIntent(source, &evt.Info) converted := portal.convertMessage(intent, source, &evt.Info, evt.Message) if converted != nil { var eventID id.EventID + if existingMsg != nil { + converted.Content.SetEdit(existingMsg.MXID) + } resp, err := portal.sendMessage(converted.Intent, converted.Type, converted.Content, evt.Info.Timestamp.UnixMilli()) if err != nil { portal.log.Errorln("Failed to send %s to Matrix: %v", msgID, err) } else { eventID = resp.EventID } - if converted.Caption != nil { + // TODO figure out how to handle captions with undecryptable messages turning decryptable + if converted.Caption != nil && existingMsg == nil { resp, err = portal.sendMessage(converted.Intent, converted.Type, converted.Content, evt.Info.Timestamp.UnixMilli()) if err != nil { portal.log.Errorln("Failed to send caption of %s to Matrix: %v", msgID, err) @@ -330,55 +378,65 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) { } } if len(eventID) != 0 { - portal.finishHandling(&evt.Info, resp.EventID) + portal.finishHandling(existingMsg, &evt.Info, resp.EventID, false) } } else if msgType == "revoke" { portal.HandleMessageRevoke(source, evt.Message.GetProtocolMessage().GetKey()) + if existingMsg != nil { + _, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ + Reason: "The undecryptable message was actually the deletion of another message", + }) + existingMsg.UpdateMXID("net.maunium.whatsapp.fake::" + existingMsg.MXID, false) + } } else { portal.log.Warnln("Unhandled message:", evt.Info, evt.Message) + if existingMsg != nil { + _, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ + Reason: "The undecryptable message contained an unsupported message type", + }) + existingMsg.UpdateMXID("net.maunium.whatsapp.fake::" + existingMsg.MXID, false) + } return } portal.bridge.Metrics.TrackWhatsAppMessage(evt.Info.Timestamp, strings.Split(msgType, " ")[0]) } -func (portal *Portal) isRecentlyHandled(id types.MessageID) bool { +func (portal *Portal) isRecentlyHandled(id types.MessageID, decryptionError bool) bool { start := portal.recentlyHandledIndex + lookingForMsg := recentlyHandledWrapper{id, decryptionError} for i := start; i != start; i = (i - 1) % recentlyHandledLength { - if portal.recentlyHandled[i] == id { + if portal.recentlyHandled[i] == lookingForMsg { return true } } return false } -func (portal *Portal) isDuplicate(id types.MessageID) bool { - msg := portal.bridge.DB.Message.GetByJID(portal.Key, id) - if msg != nil { - return true - } - return false -} - func init() { gob.Register(&waProto.Message{}) } -func (portal *Portal) markHandled(info *types.MessageInfo, mxid id.EventID, isSent, recent bool) *database.Message { - msg := portal.bridge.DB.Message.New() - msg.Chat = portal.Key - msg.JID = info.ID - msg.MXID = mxid - msg.Timestamp = info.Timestamp - msg.Sender = info.Sender - msg.Sent = isSent - msg.Insert() +func (portal *Portal) markHandled(msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent, decryptionError bool) *database.Message { + if msg == nil { + msg = portal.bridge.DB.Message.New() + msg.Chat = portal.Key + msg.JID = info.ID + msg.MXID = mxid + msg.Timestamp = info.Timestamp + msg.Sender = info.Sender + msg.Sent = isSent + msg.DecryptionError = decryptionError + msg.Insert() + } else { + msg.UpdateMXID(mxid, decryptionError) + } if recent { portal.recentlyHandledLock.Lock() index := portal.recentlyHandledIndex portal.recentlyHandledIndex = (portal.recentlyHandledIndex + 1) % recentlyHandledLength portal.recentlyHandledLock.Unlock() - portal.recentlyHandled[index] = msg.JID + portal.recentlyHandled[index] = recentlyHandledWrapper{msg.JID, decryptionError} } return msg } @@ -406,10 +464,14 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo) *app return puppet.IntentFor(portal) } -func (portal *Portal) finishHandling(message *types.MessageInfo, mxid id.EventID) { - portal.markHandled(message, mxid, true, true) +func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, decryptionError bool) { + portal.markHandled(existing, message, mxid, true, true, decryptionError) portal.sendDeliveryReceipt(mxid) - portal.log.Debugln("Handled message", message.ID, "->", mxid) + if !decryptionError { + portal.log.Debugln("Handled message", message.ID, "->", mxid) + } else { + portal.log.Debugln("Handled message", message.ID, "->", mxid, "(undecryptable message error notice)") + } } func (portal *Portal) kickExtraUsers(participantMap map[types.JID]bool) { @@ -896,13 +958,13 @@ func (portal *Portal) finishBatch(eventIDs []id.EventID, infos []*types.MessageI } else if info, ok := infoMap[types.MessageID(msgID)]; !ok { portal.log.Warnfln("Didn't find info of message %s (event %s) to register it in the database", msgID, eventID) } else { - portal.markHandled(info, eventID, true, false) + portal.markHandled(nil, info, eventID, true, false, false) } } } else { for i := 0; i < len(infos); i++ { if infos[i] != nil { - portal.markHandled(infos[i], eventIDs[i], true, false) + portal.markHandled(nil, infos[i], eventIDs[i], true, false, false) } } portal.log.Infofln("Successfully sent %d events", len(eventIDs)) @@ -2358,7 +2420,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) { return } info := portal.generateMessageInfo(sender) - dbMsg := portal.markHandled(info, evt.ID, false, true) + dbMsg := portal.markHandled(nil, info, evt.ID, false, true, false) portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID) err := sender.Client.SendMessage(portal.Key.JID, info.ID, msg) if err != nil { diff --git a/user.go b/user.go index 01a5f16..b42ee15 100644 --- a/user.go +++ b/user.go @@ -436,7 +436,10 @@ func (user *User) HandleEvent(event interface{}) { go user.handleReceipt(v) case *events.Message: portal := user.GetPortalByJID(v.Info.Chat) - portal.messages <- PortalMessage{v, user} + portal.messages <- PortalMessage{evt: v, source: user} + case *events.UndecryptableMessage: + portal := user.GetPortalByJID(v.Info.Chat) + portal.messages <- PortalMessage{undecryptable: v, source: user} case *events.HistorySync: user.historySyncs <- v case *events.Mute: @@ -458,6 +461,8 @@ func (user *User) HandleEvent(event interface{}) { if portal != nil { go user.updateChatTag(nil, portal, user.bridge.Config.Bridge.PinnedTag, v.Action.GetPinned()) } + case *events.AppState: + // Ignore default: user.log.Debugfln("Unknown type of event in HandleEvent: %T", v) }