Add support for creating polls from Matrix

This commit is contained in:
Tulir Asokan 2022-12-23 15:17:57 +02:00
parent 52876bb607
commit 0305680317
10 changed files with 374 additions and 73 deletions

View File

@ -1,3 +1,7 @@
# unreleased
* Added support for sending polls from Matrix to WhatsApp.
# v0.8.0 (2022-12-16) # v0.8.0 (2022-12-16)
* Added support for bridging polls from WhatsApp and votes in both directions. * Added support for bridging polls from WhatsApp and votes in both directions.

View File

@ -6,7 +6,7 @@
* [x] Location messages * [x] Location messages
* [x] Media/files * [x] Media/files
* [x] Replies * [x] Replies
* [ ] Polls * [x] Polls
* [x] Poll votes * [x] Poll votes
* [x] Message redactions * [x] Message redactions
* [x] Reactions * [x] Reactions

View File

@ -134,11 +134,12 @@ const (
type MessageType string type MessageType string
const ( const (
MsgUnknown MessageType = "" MsgUnknown MessageType = ""
MsgFake MessageType = "fake" MsgFake MessageType = "fake"
MsgNormal MessageType = "message" MsgNormal MessageType = "message"
MsgReaction MessageType = "reaction" MsgReaction MessageType = "reaction"
MsgEdit MessageType = "edit" MsgEdit MessageType = "edit"
MsgMatrixPoll MessageType = "matrix-poll"
) )
type Message struct { type Message struct {

118
database/polloption.go Normal file
View File

@ -0,0 +1,118 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"fmt"
"strings"
"github.com/lib/pq"
"maunium.net/go/mautrix/util/dbutil"
)
func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
var hash []byte
err = rows.Scan(&id, &hash)
if err != nil {
// return below
} else if len(hash) != 32 {
err = fmt.Errorf("unexpected hash length %d", len(hash))
} else {
hashArr = *(*[32]byte)(hash)
}
return
}
func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
args := make([]any, len(opts)*2+1)
placeholders := make([]string, len(opts))
args[0] = msg.MXID
i := 0
for hash, id := range opts {
args[i*2+1] = id
hashCopy := hash
args[i*2+2] = hashCopy[:]
placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
i++
}
query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
_, err := msg.db.Exec(query, args...)
if err != nil {
msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
}
}
func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
var args []any
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(hashes)}
} else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
args = make([]any, len(hashes)+1)
args[0] = msg.MXID
for i, hash := range hashes {
args[i+1] = hash
}
}
ids := make(map[[32]byte]string, len(hashes))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
break
}
ids[hash] = id
}
}
return ids
}
func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
var args []any
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(ids)}
} else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
args = make([]any, len(ids)+1)
args[0] = msg.MXID
for i, id := range ids {
args[i+1] = id
}
}
hashes := make(map[string][32]byte, len(ids))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
break
}
hashes[id] = hash
}
}
return hashes
}

View File

@ -1,4 +1,4 @@
-- v0 -> v52: Latest revision -- v0 -> v54: Latest revision
CREATE TABLE "user" ( CREATE TABLE "user" (
mxid TEXT PRIMARY KEY, mxid TEXT PRIMARY KEY,
@ -40,6 +40,7 @@ CREATE TABLE portal (
PRIMARY KEY (jid, receiver) PRIMARY KEY (jid, receiver)
); );
CREATE INDEX portal_parent_group_idx ON portal(parent_group);
CREATE TABLE puppet ( CREATE TABLE puppet (
username TEXT PRIMARY KEY, username TEXT PRIMARY KEY,
@ -79,6 +80,16 @@ CREATE TABLE message (
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
); );
CREATE TABLE poll_option_id (
msg_mxid TEXT,
opt_id TEXT,
opt_hash bytea CHECK ( length(opt_hash) = 32 ),
PRIMARY KEY (msg_mxid, opt_id),
CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE reaction ( CREATE TABLE reaction (
chat_jid TEXT, chat_jid TEXT,
chat_receiver TEXT, chat_receiver TEXT,

View File

@ -0,0 +1,11 @@
-- v54: Store mapping for poll option IDs from Matrix
CREATE TABLE poll_option_id (
msg_mxid TEXT,
opt_id TEXT,
opt_hash bytea CHECK ( length(opt_hash) = 32 ),
PRIMARY KEY (msg_mxid, opt_id),
CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
);

View File

@ -35,7 +35,8 @@ var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)
var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```") var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`) var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`)
const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids" const mentionedJIDsContextKey = "fi.mau.whatsapp.mentioned_jids"
const disableMentionsContextKey = "fi.mau.whatsapp.no_mentions"
type Formatter struct { type Formatter struct {
bridge *WABridge bridge *WABridge
@ -55,7 +56,8 @@ func NewFormatter(bridge *WABridge) *Formatter {
Newline: "\n", Newline: "\n",
PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string { PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
if mxid[0] == '@' { _, disableMentions := ctx[disableMentionsContextKey]
if mxid[0] == '@' && !disableMentions {
puppet := bridge.GetPuppetByMXID(id.UserID(mxid)) puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
if puppet != nil { if puppet != nil {
jids, ok := ctx[mentionedJIDsContextKey].([]string) jids, ok := ctx[mentionedJIDsContextKey].([]string)
@ -67,7 +69,7 @@ func NewFormatter(bridge *WABridge) *Formatter {
return "@" + puppet.JID.User return "@" + puppet.JID.User
} }
} }
return mxid return displayname
}, },
BoldConverter: func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) }, BoldConverter: func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) },
ItalicConverter: func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) }, ItalicConverter: func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) },
@ -151,3 +153,9 @@ func (formatter *Formatter) ParseMatrix(html string) (string, []string) {
mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string) mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string)
return result, mentionedJIDs return result, mentionedJIDs
} }
func (formatter *Formatter) ParseMatrixWithoutMentions(html string) string {
ctx := make(format.Context)
ctx[disableMentionsContextKey] = true
return formatter.matrixHTMLParser.Parse(html, ctx)
}

View File

@ -87,8 +87,9 @@ func (br *WABridge) Init() {
// TODO this is a weird place for this // TODO this is a weird place for this
br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence) br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence)
br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage) br.EventProcessor.On(TypeMSC3381PollStart, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage) br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
Segment.log = br.Log.Sub("Segment") Segment.log = br.Log.Sub("Segment")
Segment.key = br.Config.SegmentKey Segment.key = br.Config.SegmentKey

View File

@ -52,6 +52,8 @@ var (
errTargetIsFake = errors.New("target is a fake event") errTargetIsFake = errors.New("target is a fake event")
errReactionSentBySomeoneElse = errors.New("target reaction was sent by someone else") errReactionSentBySomeoneElse = errors.New("target reaction was sent by someone else")
errDMSentByOtherUser = errors.New("target message was sent by the other user in a DM") errDMSentByOtherUser = errors.New("target message was sent by the other user in a DM")
errPollMissingQuestion = errors.New("poll message is missing question")
errPollDuplicateOption = errors.New("poll options must be unique")
errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported") errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported")
errBroadcastSendDisabled = errors.New("sending status messages is disabled") errBroadcastSendDisabled = errors.New("sending status messages is disabled")
@ -76,7 +78,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, "" return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
case errors.Is(err, errMNoticeDisabled): case errors.Is(err, errMNoticeDisabled):
return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, "" return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, ""
case errors.Is(err, errMediaUnsupportedType): case errors.Is(err, errMediaUnsupportedType), errors.Is(err, errPollMissingQuestion), errors.Is(err, errPollDuplicateOption):
return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error() return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error()
case errors.Is(err, errTimeoutBeforeHandling): case errors.Is(err, errTimeoutBeforeHandling):
return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled" return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
@ -185,8 +187,10 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
msgType = "reaction" msgType = "reaction"
case event.EventRedaction: case event.EventRedaction:
msgType = "redaction" msgType = "redaction"
case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse: case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response" msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default: default:
msgType = "unknown event" msgType = "unknown event"
} }

261
portal.go
View File

@ -19,6 +19,7 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -322,7 +323,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) {
portal.handleMatrixReadReceipt(msg.user, "", evtTS, false) portal.handleMatrixReadReceipt(msg.user, "", evtTS, false)
timings.implicitRR = time.Since(implicitRRStart) timings.implicitRR = time.Since(implicitRRStart)
switch msg.evt.Type { switch msg.evt.Type {
case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse: case event.EventMessage, event.EventSticker, TypeMSC3381V2PollResponse, TypeMSC3381PollResponse, TypeMSC3381PollStart:
portal.HandleMatrixMessage(msg.user, msg.evt, timings) portal.HandleMatrixMessage(msg.user, msg.evt, timings)
case event.EventRedaction: case event.EventRedaction:
portal.HandleMatrixRedaction(msg.user, msg.evt) portal.HandleMatrixRedaction(msg.user, msg.evt)
@ -2267,9 +2268,6 @@ func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, m
} }
func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage { func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage {
if !portal.bridge.Config.Bridge.ExtEvPolls {
return nil
}
pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId()) pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId())
if pollMessage == nil { if pollMessage == nil {
portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId()) portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId())
@ -2284,13 +2282,28 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou
return nil return nil
} }
selectedHashes := make([]string, len(vote.GetSelectedOptions())) selectedHashes := make([]string, len(vote.GetSelectedOptions()))
for i, opt := range vote.GetSelectedOptions() { if pollMessage.Type == database.MsgMatrixPoll {
selectedHashes[i] = hex.EncodeToString(opt) mappedAnswers := pollMessage.GetPollOptionIDs(vote.GetSelectedOptions())
for i, opt := range vote.GetSelectedOptions() {
if len(opt) != 32 {
portal.log.Warnfln("Unexpected option hash length %d in %s's vote to %s", len(opt), info.Sender, pollMessage.MXID)
continue
}
var ok bool
selectedHashes[i], ok = mappedAnswers[*(*[32]byte)(opt)]
if !ok {
portal.log.Warnfln("Didn't find ID for option %X in %s's vote to %s", opt, info.Sender, pollMessage.MXID)
}
}
} else {
for i, opt := range vote.GetSelectedOptions() {
selectedHashes[i] = hex.EncodeToString(opt)
}
} }
evtType := TypeMSC3881PollResponse evtType := TypeMSC3381PollResponse
//if portal.bridge.Config.Bridge.ExtEvPolls == 2 { //if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
// evtType = TypeMSC3881V2PollResponse // evtType = TypeMSC3381V2PollResponse
//} //}
return &ConvertedMessage{ return &ConvertedMessage{
Intent: intent, Intent: intent,
@ -2341,7 +2354,7 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m
} }
evtType := event.EventMessage evtType := event.EventMessage
if portal.bridge.Config.Bridge.ExtEvPolls { if portal.bridge.Config.Bridge.ExtEvPolls {
evtType.Type = "org.matrix.msc3381.poll.start" evtType = TypeMSC3381PollStart
} }
//else if portal.bridge.Config.Bridge.ExtEvPolls == 2 { //else if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
// evtType.Type = "org.matrix.msc3381.v2.poll.start" // evtType.Type = "org.matrix.msc3381.v2.poll.start"
@ -3505,8 +3518,9 @@ func getUnstableWaveform(content map[string]interface{}) []byte {
} }
var ( var (
TypeMSC3881PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"} TypeMSC3381PollStart = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"}
TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"} TypeMSC3381PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
TypeMSC3381V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.v2.poll.response"}
) )
type PollResponseContent struct { type PollResponseContent struct {
@ -3532,15 +3546,71 @@ func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) {
content.RelatesTo = *rel content.RelatesTo = *rel
} }
func init() { type MSC1767Message struct {
event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{}) Text string `json:"org.matrix.msc1767.text,omitempty"`
event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{}) HTML string `json:"org.matrix.msc1767.html,omitempty"`
Message []struct {
MimeType string `json:"mimetype"`
Body string `json:"body"`
} `json:"org.matrix.msc1767.message,omitempty"`
} }
func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) { func (portal *Portal) msc1767ToWhatsApp(msg MSC1767Message, mentions bool) (string, []string) {
for _, part := range msg.Message {
if part.MimeType == "text/html" && msg.HTML == "" {
msg.HTML = part.Body
} else if part.MimeType == "text/plain" && msg.Text == "" {
msg.Text = part.Body
}
}
if msg.HTML != "" {
if mentions {
return portal.bridge.Formatter.ParseMatrix(msg.HTML)
} else {
return portal.bridge.Formatter.ParseMatrixWithoutMentions(msg.HTML), nil
}
}
return msg.Text, nil
}
type PollStartContent struct {
RelatesTo *event.RelatesTo `json:"m.relates_to"`
PollStart struct {
Kind string `json:"kind"`
MaxSelections int `json:"max_selections"`
Question MSC1767Message `json:"question"`
Answers []struct {
ID string `json:"id"`
MSC1767Message
} `json:"answers"`
} `json:"org.matrix.msc3381.poll.start"`
}
func (content *PollStartContent) GetRelatesTo() *event.RelatesTo {
if content.RelatesTo == nil {
content.RelatesTo = &event.RelatesTo{}
}
return content.RelatesTo
}
func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo {
return content.RelatesTo
}
func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) {
content.RelatesTo = rel
}
func init() {
event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{})
event.TypeMap[TypeMSC3381V2PollResponse] = reflect.TypeOf(PollResponseContent{})
event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{})
}
func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
content, ok := evt.Content.Parsed.(*PollResponseContent) content, ok := evt.Content.Parsed.(*PollResponseContent)
if !ok { if !ok {
return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
} }
var answers []string var answers []string
if content.V1Response.Answers != nil { if content.V1Response.Answers != nil {
@ -3550,7 +3620,7 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
} }
pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID) pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID)
if pollMsg == nil { if pollMsg == nil {
return nil, sender, errTargetNotFound return nil, sender, nil, errTargetNotFound
} }
pollMsgInfo := &types.MessageInfo{ pollMsgInfo := &types.MessageInfo{
MessageSource: types.MessageSource{ MessageSource: types.MessageSource{
@ -3563,43 +3633,81 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
Type: "poll", Type: "poll",
} }
optionHashes := make([][]byte, 0, len(answers)) optionHashes := make([][]byte, 0, len(answers))
for _, selection := range answers { if pollMsg.Type == database.MsgMatrixPoll {
hash, _ := hex.DecodeString(selection) mappedAnswers := pollMsg.GetPollOptionHashes(answers)
if hash != nil && len(hash) == 32 { for _, selection := range answers {
optionHashes = append(optionHashes, hash) hash, ok := mappedAnswers[selection]
if ok {
optionHashes = append(optionHashes, hash[:])
} else {
portal.log.Warnfln("Didn't find hash for option %s in %s's vote to %s", selection, evt.Sender, pollMsg.MXID)
}
}
} else {
for _, selection := range answers {
hash, _ := hex.DecodeString(selection)
if hash != nil && len(hash) == 32 {
optionHashes = append(optionHashes, hash)
}
} }
} }
pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{ pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{
SelectedOptions: optionHashes, SelectedOptions: optionHashes,
}) })
return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, nil, err
} }
func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) { func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse { content, ok := evt.Content.Parsed.(*PollStartContent)
return portal.convertMatrixPollVote(ctx, sender, evt)
}
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok { if !ok {
return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
} }
var editRootMsg *database.Message maxAnswers := content.PollStart.MaxSelections
if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits { if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 {
editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID) maxAnswers = 0
if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User { }
return nil, sender, fmt.Errorf("edit rejected") // TODO more specific error message fmt.Printf("%+v\n", content.PollStart)
ctxInfo := portal.generateContextInfo(content.RelatesTo)
var question string
question, ctxInfo.MentionedJid = portal.msc1767ToWhatsApp(content.PollStart.Question, true)
if len(question) == 0 {
return nil, sender, nil, errPollMissingQuestion
}
options := make([]*waProto.PollCreationMessage_Option, len(content.PollStart.Answers))
optionMap := make(map[[32]byte]string, len(options))
for i, opt := range content.PollStart.Answers {
body, _ := portal.msc1767ToWhatsApp(opt.MSC1767Message, false)
hash := sha256.Sum256([]byte(body))
if _, alreadyExists := optionMap[hash]; alreadyExists {
portal.log.Warnfln("Poll %s by %s has option %q more than once, rejecting", evt.ID, evt.Sender, body)
return nil, sender, nil, errPollDuplicateOption
} }
if content.NewContent != nil { optionMap[hash] = opt.ID
content = content.NewContent options[i] = &waProto.PollCreationMessage_Option{
OptionName: proto.String(body),
} }
} }
secret := make([]byte, 32)
_, err := rand.Read(secret)
return &waProto.Message{
PollCreationMessage: &waProto.PollCreationMessage{
Name: proto.String(question),
Options: options,
SelectableOptionsCount: proto.Uint32(uint32(maxAnswers)),
ContextInfo: ctxInfo,
},
MessageContextInfo: &waProto.MessageContextInfo{
MessageSecret: secret,
},
}, sender, &extraConvertMeta{PollOptions: optionMap}, err
}
msg := &waProto.Message{} func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.ContextInfo {
var ctxInfo waProto.ContextInfo var ctxInfo waProto.ContextInfo
replyToID := content.RelatesTo.GetReplyTo() replyToID := relatesTo.GetReplyTo()
if len(replyToID) > 0 { if len(replyToID) > 0 {
replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID) replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID)
if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal { if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll) {
ctxInfo.StanzaId = &replyToMsg.JID ctxInfo.StanzaId = &replyToMsg.JID
ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String()) ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String())
// Using blank content here seems to work fine on all official WhatsApp apps. // Using blank content here seems to work fine on all official WhatsApp apps.
@ -3613,10 +3721,40 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
if portal.ExpirationTime != 0 { if portal.ExpirationTime != 0 {
ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime) ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime)
} }
return &ctxInfo
}
type extraConvertMeta struct {
PollOptions map[[32]byte]string
}
func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
if evt.Type == TypeMSC3381PollResponse || evt.Type == TypeMSC3381V2PollResponse {
return portal.convertMatrixPollVote(ctx, sender, evt)
} else if evt.Type == TypeMSC3381PollStart {
return portal.convertMatrixPollStart(ctx, sender, evt)
}
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok {
return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
}
var editRootMsg *database.Message
if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
return nil, sender, nil, fmt.Errorf("edit rejected") // TODO more specific error message
}
if content.NewContent != nil {
content = content.NewContent
}
}
msg := &waProto.Message{}
ctxInfo := portal.generateContextInfo(content.RelatesTo)
relaybotFormatted := false relaybotFormatted := false
if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) { if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) {
if !portal.HasRelaybot() { if !portal.HasRelaybot() {
return nil, sender, errUserNotLoggedIn return nil, sender, nil, errUserNotLoggedIn
} }
relaybotFormatted = portal.addRelaybotFormat(sender, content) relaybotFormatted = portal.addRelaybotFormat(sender, content)
sender = portal.GetRelayUser() sender = portal.GetRelayUser()
@ -3637,7 +3775,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgText, event.MsgEmote, event.MsgNotice: case event.MsgText, event.MsgEmote, event.MsgNotice:
text := content.Body text := content.Body
if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices { if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices {
return nil, sender, errMNoticeDisabled return nil, sender, nil, errMNoticeDisabled
} }
if content.Format == event.FormatHTML { if content.Format == event.FormatHTML {
text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody) text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
@ -3647,11 +3785,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
} }
msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{ msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{
Text: &text, Text: &text,
ContextInfo: &ctxInfo, ContextInfo: ctxInfo,
} }
hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage) hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage)
if ctx.Err() != nil { if ctx.Err() != nil {
return nil, nil, ctx.Err() return nil, nil, nil, ctx.Err()
} }
if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview { if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview {
// No need for extended message // No need for extended message
@ -3661,11 +3799,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgImage: case event.MsgImage:
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage) media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
if media == nil { if media == nil {
return nil, sender, err return nil, sender, nil, err
} }
ctxInfo.MentionedJid = media.MentionedJIDs ctxInfo.MentionedJid = media.MentionedJIDs
msg.ImageMessage = &waProto.ImageMessage{ msg.ImageMessage = &waProto.ImageMessage{
ContextInfo: &ctxInfo, ContextInfo: ctxInfo,
Caption: &media.Caption, Caption: &media.Caption,
JpegThumbnail: media.Thumbnail, JpegThumbnail: media.Thumbnail,
Url: &media.URL, Url: &media.URL,
@ -3678,11 +3816,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MessageType(event.EventSticker.Type): case event.MessageType(event.EventSticker.Type):
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage) media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
if media == nil { if media == nil {
return nil, sender, err return nil, sender, nil, err
} }
ctxInfo.MentionedJid = media.MentionedJIDs ctxInfo.MentionedJid = media.MentionedJIDs
msg.StickerMessage = &waProto.StickerMessage{ msg.StickerMessage = &waProto.StickerMessage{
ContextInfo: &ctxInfo, ContextInfo: ctxInfo,
PngThumbnail: media.Thumbnail, PngThumbnail: media.Thumbnail,
Url: &media.URL, Url: &media.URL,
MediaKey: media.MediaKey, MediaKey: media.MediaKey,
@ -3695,12 +3833,12 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
gifPlayback := content.GetInfo().MimeType == "image/gif" gifPlayback := content.GetInfo().MimeType == "image/gif"
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo) media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)
if media == nil { if media == nil {
return nil, sender, err return nil, sender, nil, err
} }
duration := uint32(content.GetInfo().Duration / 1000) duration := uint32(content.GetInfo().Duration / 1000)
ctxInfo.MentionedJid = media.MentionedJIDs ctxInfo.MentionedJid = media.MentionedJIDs
msg.VideoMessage = &waProto.VideoMessage{ msg.VideoMessage = &waProto.VideoMessage{
ContextInfo: &ctxInfo, ContextInfo: ctxInfo,
Caption: &media.Caption, Caption: &media.Caption,
JpegThumbnail: media.Thumbnail, JpegThumbnail: media.Thumbnail,
Url: &media.URL, Url: &media.URL,
@ -3715,11 +3853,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgAudio: case event.MsgAudio:
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio) media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio)
if media == nil { if media == nil {
return nil, sender, err return nil, sender, nil, err
} }
duration := uint32(content.GetInfo().Duration / 1000) duration := uint32(content.GetInfo().Duration / 1000)
msg.AudioMessage = &waProto.AudioMessage{ msg.AudioMessage = &waProto.AudioMessage{
ContextInfo: &ctxInfo, ContextInfo: ctxInfo,
Url: &media.URL, Url: &media.URL,
MediaKey: media.MediaKey, MediaKey: media.MediaKey,
Mimetype: &content.GetInfo().MimeType, Mimetype: &content.GetInfo().MimeType,
@ -3738,10 +3876,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgFile: case event.MsgFile:
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument) media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument)
if media == nil { if media == nil {
return nil, sender, err return nil, sender, nil, err
} }
msg.DocumentMessage = &waProto.DocumentMessage{ msg.DocumentMessage = &waProto.DocumentMessage{
ContextInfo: &ctxInfo, ContextInfo: ctxInfo,
Caption: &media.Caption, Caption: &media.Caption,
JpegThumbnail: media.Thumbnail, JpegThumbnail: media.Thumbnail,
Url: &media.URL, Url: &media.URL,
@ -3764,16 +3902,16 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgLocation: case event.MsgLocation:
lat, long, err := parseGeoURI(content.GeoURI) lat, long, err := parseGeoURI(content.GeoURI)
if err != nil { if err != nil {
return nil, sender, fmt.Errorf("%w: %v", errInvalidGeoURI, err) return nil, sender, nil, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
} }
msg.LocationMessage = &waProto.LocationMessage{ msg.LocationMessage = &waProto.LocationMessage{
DegreesLatitude: &lat, DegreesLatitude: &lat,
DegreesLongitude: &long, DegreesLongitude: &long,
Comment: &content.Body, Comment: &content.Body,
ContextInfo: &ctxInfo, ContextInfo: ctxInfo,
} }
default: default:
return nil, sender, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType) return nil, sender, nil, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
} }
if editRootMsg != nil { if editRootMsg != nil {
@ -3795,7 +3933,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
} }
} }
return msg, sender, nil return msg, sender, nil, nil
} }
func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo { func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo {
@ -3815,7 +3953,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
start := time.Now() start := time.Now()
ms := metricSender{portal: portal, timings: &timings} ms := metricSender{portal: portal, timings: &timings}
allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse allowRelay := evt.Type != TypeMSC3381PollResponse && evt.Type != TypeMSC3381V2PollResponse && evt.Type != TypeMSC3381PollStart
if err := portal.canBridgeFrom(sender, allowRelay); err != nil { if err := portal.canBridgeFrom(sender, allowRelay); err != nil {
go ms.sendMessageMetrics(evt, err, "Ignoring", true) go ms.sendMessageMetrics(evt, err, "Ignoring", true)
return return
@ -3875,14 +4013,16 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
timings.preproc = time.Since(start) timings.preproc = time.Since(start)
start = time.Now() start = time.Now()
msg, sender, err := portal.convertMatrixMessage(ctx, sender, evt) msg, sender, extraMeta, err := portal.convertMatrixMessage(ctx, sender, evt)
timings.convert = time.Since(start) timings.convert = time.Since(start)
if msg == nil { if msg == nil {
go ms.sendMessageMetrics(evt, err, "Error converting", true) go ms.sendMessageMetrics(evt, err, "Error converting", true)
return return
} }
dbMsgType := database.MsgNormal dbMsgType := database.MsgNormal
if msg.EditedMessage == nil { if msg.PollCreationMessage != nil || msg.PollCreationMessageV2 != nil {
dbMsgType = database.MsgMatrixPoll
} else if msg.EditedMessage == nil {
portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true) portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true)
} else { } else {
dbMsgType = database.MsgEdit dbMsgType = database.MsgEdit
@ -3893,6 +4033,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
} else { } else {
info.ID = dbMsg.JID info.ID = dbMsg.JID
} }
if dbMsgType == database.MsgMatrixPoll && extraMeta != nil && extraMeta.PollOptions != nil {
dbMsg.PutPollOptions(extraMeta.PollOptions)
}
portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID) portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
start = time.Now() start = time.Now()
resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg) resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg)