diff --git a/CHANGELOG.md b/CHANGELOG.md index e73a7c5..741d08a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# unreleased + +* Added support for sending polls from Matrix to WhatsApp. + # v0.8.0 (2022-12-16) * Added support for bridging polls from WhatsApp and votes in both directions. diff --git a/ROADMAP.md b/ROADMAP.md index ac83083..3106216 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -6,7 +6,7 @@ * [x] Location messages * [x] Media/files * [x] Replies - * [ ] Polls + * [x] Polls * [x] Poll votes * [x] Message redactions * [x] Reactions diff --git a/database/message.go b/database/message.go index 767294a..8e02bbd 100644 --- a/database/message.go +++ b/database/message.go @@ -134,11 +134,12 @@ const ( type MessageType string const ( - MsgUnknown MessageType = "" - MsgFake MessageType = "fake" - MsgNormal MessageType = "message" - MsgReaction MessageType = "reaction" - MsgEdit MessageType = "edit" + MsgUnknown MessageType = "" + MsgFake MessageType = "fake" + MsgNormal MessageType = "message" + MsgReaction MessageType = "reaction" + MsgEdit MessageType = "edit" + MsgMatrixPoll MessageType = "matrix-poll" ) type Message struct { diff --git a/database/polloption.go b/database/polloption.go new file mode 100644 index 0000000..aedcd53 --- /dev/null +++ b/database/polloption.go @@ -0,0 +1,118 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2022 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "fmt" + "strings" + + "github.com/lib/pq" + "maunium.net/go/mautrix/util/dbutil" +) + +func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) { + var hash []byte + err = rows.Scan(&id, &hash) + if err != nil { + // return below + } else if len(hash) != 32 { + err = fmt.Errorf("unexpected hash length %d", len(hash)) + } else { + hashArr = *(*[32]byte)(hash) + } + return +} + +func (msg *Message) PutPollOptions(opts map[[32]byte]string) { + query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)" + args := make([]any, len(opts)*2+1) + placeholders := make([]string, len(opts)) + args[0] = msg.MXID + i := 0 + for hash, id := range opts { + args[i*2+1] = id + hashCopy := hash + args[i*2+2] = hashCopy[:] + placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3) + i++ + } + query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ",")) + _, err := msg.db.Exec(query, args...) + if err != nil { + msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err) + } +} + +func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string { + query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)" + var args []any + if msg.db.Dialect == dbutil.Postgres { + args = []any{msg.MXID, pq.Array(hashes)} + } else { + query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ","))) + args = make([]any, len(hashes)+1) + args[0] = msg.MXID + for i, hash := range hashes { + args[i+1] = hash + } + } + ids := make(map[[32]byte]string, len(hashes)) + rows, err := msg.db.Query(query, args...) + if err != nil { + msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err) + } else { + for rows.Next() { + id, hash, err := scanPollOptionMapping(rows) + if err != nil { + msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err) + break + } + ids[hash] = id + } + } + return ids +} + +func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte { + query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)" + var args []any + if msg.db.Dialect == dbutil.Postgres { + args = []any{msg.MXID, pq.Array(ids)} + } else { + query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ","))) + args = make([]any, len(ids)+1) + args[0] = msg.MXID + for i, id := range ids { + args[i+1] = id + } + } + hashes := make(map[string][32]byte, len(ids)) + rows, err := msg.db.Query(query, args...) + if err != nil { + msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err) + } else { + for rows.Next() { + id, hash, err := scanPollOptionMapping(rows) + if err != nil { + msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err) + break + } + hashes[id] = hash + } + } + return hashes +} diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 7f7a611..7612b92 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v52: Latest revision +-- v0 -> v54: Latest revision CREATE TABLE "user" ( mxid TEXT PRIMARY KEY, @@ -40,6 +40,7 @@ CREATE TABLE portal ( PRIMARY KEY (jid, receiver) ); +CREATE INDEX portal_parent_group_idx ON portal(parent_group); CREATE TABLE puppet ( username TEXT PRIMARY KEY, @@ -79,6 +80,16 @@ CREATE TABLE message ( FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ); +CREATE TABLE poll_option_id ( + msg_mxid TEXT, + opt_id TEXT, + opt_hash bytea CHECK ( length(opt_hash) = 32 ), + + PRIMARY KEY (msg_mxid, opt_id), + CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash), + CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE +); + CREATE TABLE reaction ( chat_jid TEXT, chat_receiver TEXT, diff --git a/database/upgrades/54-poll-option-id-map.sql b/database/upgrades/54-poll-option-id-map.sql new file mode 100644 index 0000000..2593fc3 --- /dev/null +++ b/database/upgrades/54-poll-option-id-map.sql @@ -0,0 +1,11 @@ +-- v54: Store mapping for poll option IDs from Matrix + +CREATE TABLE poll_option_id ( + msg_mxid TEXT, + opt_id TEXT, + opt_hash bytea CHECK ( length(opt_hash) = 32 ), + + PRIMARY KEY (msg_mxid, opt_id), + CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash), + CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE +); diff --git a/formatting.go b/formatting.go index 1418d79..3a0d870 100644 --- a/formatting.go +++ b/formatting.go @@ -35,7 +35,8 @@ var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$) var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```") var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`) -const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids" +const mentionedJIDsContextKey = "fi.mau.whatsapp.mentioned_jids" +const disableMentionsContextKey = "fi.mau.whatsapp.no_mentions" type Formatter struct { bridge *WABridge @@ -55,7 +56,8 @@ func NewFormatter(bridge *WABridge) *Formatter { Newline: "\n", PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string { - if mxid[0] == '@' { + _, disableMentions := ctx[disableMentionsContextKey] + if mxid[0] == '@' && !disableMentions { puppet := bridge.GetPuppetByMXID(id.UserID(mxid)) if puppet != nil { jids, ok := ctx[mentionedJIDsContextKey].([]string) @@ -67,7 +69,7 @@ func NewFormatter(bridge *WABridge) *Formatter { return "@" + puppet.JID.User } } - return mxid + return displayname }, BoldConverter: func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) }, ItalicConverter: func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) }, @@ -151,3 +153,9 @@ func (formatter *Formatter) ParseMatrix(html string) (string, []string) { mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string) return result, mentionedJIDs } + +func (formatter *Formatter) ParseMatrixWithoutMentions(html string) string { + ctx := make(format.Context) + ctx[disableMentionsContextKey] = true + return formatter.matrixHTMLParser.Parse(html, ctx) +} diff --git a/main.go b/main.go index 9fbb7c0..e4d5a67 100644 --- a/main.go +++ b/main.go @@ -87,8 +87,9 @@ func (br *WABridge) Init() { // TODO this is a weird place for this br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence) - br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage) - br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage) + br.EventProcessor.On(TypeMSC3381PollStart, br.MatrixHandler.HandleMessage) + br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage) + br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage) Segment.log = br.Log.Sub("Segment") Segment.key = br.Config.SegmentKey diff --git a/messagetracking.go b/messagetracking.go index 72770fd..51bc2bd 100644 --- a/messagetracking.go +++ b/messagetracking.go @@ -52,6 +52,8 @@ var ( errTargetIsFake = errors.New("target is a fake event") errReactionSentBySomeoneElse = errors.New("target reaction was sent by someone else") errDMSentByOtherUser = errors.New("target message was sent by the other user in a DM") + errPollMissingQuestion = errors.New("poll message is missing question") + errPollDuplicateOption = errors.New("poll options must be unique") errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported") errBroadcastSendDisabled = errors.New("sending status messages is disabled") @@ -76,7 +78,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, "" case errors.Is(err, errMNoticeDisabled): return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, "" - case errors.Is(err, errMediaUnsupportedType): + case errors.Is(err, errMediaUnsupportedType), errors.Is(err, errPollMissingQuestion), errors.Is(err, errPollDuplicateOption): return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error() case errors.Is(err, errTimeoutBeforeHandling): return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled" @@ -185,8 +187,10 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin msgType = "reaction" case event.EventRedaction: msgType = "redaction" - case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse: + case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse: msgType = "poll response" + case TypeMSC3381PollStart: + msgType = "poll start" default: msgType = "unknown event" } diff --git a/portal.go b/portal.go index 84db816..a871cf1 100644 --- a/portal.go +++ b/portal.go @@ -19,6 +19,7 @@ package main import ( "bytes" "context" + "crypto/rand" "crypto/sha256" "encoding/hex" "encoding/json" @@ -322,7 +323,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) { portal.handleMatrixReadReceipt(msg.user, "", evtTS, false) timings.implicitRR = time.Since(implicitRRStart) switch msg.evt.Type { - case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse: + case event.EventMessage, event.EventSticker, TypeMSC3381V2PollResponse, TypeMSC3381PollResponse, TypeMSC3381PollStart: portal.HandleMatrixMessage(msg.user, msg.evt, timings) case event.EventRedaction: portal.HandleMatrixRedaction(msg.user, msg.evt) @@ -2267,9 +2268,6 @@ func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, m } func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage { - if !portal.bridge.Config.Bridge.ExtEvPolls { - return nil - } pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId()) if pollMessage == nil { portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId()) @@ -2284,13 +2282,28 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou return nil } selectedHashes := make([]string, len(vote.GetSelectedOptions())) - for i, opt := range vote.GetSelectedOptions() { - selectedHashes[i] = hex.EncodeToString(opt) + if pollMessage.Type == database.MsgMatrixPoll { + mappedAnswers := pollMessage.GetPollOptionIDs(vote.GetSelectedOptions()) + for i, opt := range vote.GetSelectedOptions() { + if len(opt) != 32 { + portal.log.Warnfln("Unexpected option hash length %d in %s's vote to %s", len(opt), info.Sender, pollMessage.MXID) + continue + } + var ok bool + selectedHashes[i], ok = mappedAnswers[*(*[32]byte)(opt)] + if !ok { + portal.log.Warnfln("Didn't find ID for option %X in %s's vote to %s", opt, info.Sender, pollMessage.MXID) + } + } + } else { + for i, opt := range vote.GetSelectedOptions() { + selectedHashes[i] = hex.EncodeToString(opt) + } } - evtType := TypeMSC3881PollResponse + evtType := TypeMSC3381PollResponse //if portal.bridge.Config.Bridge.ExtEvPolls == 2 { - // evtType = TypeMSC3881V2PollResponse + // evtType = TypeMSC3381V2PollResponse //} return &ConvertedMessage{ Intent: intent, @@ -2341,7 +2354,7 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m } evtType := event.EventMessage if portal.bridge.Config.Bridge.ExtEvPolls { - evtType.Type = "org.matrix.msc3381.poll.start" + evtType = TypeMSC3381PollStart } //else if portal.bridge.Config.Bridge.ExtEvPolls == 2 { // evtType.Type = "org.matrix.msc3381.v2.poll.start" @@ -3505,8 +3518,9 @@ func getUnstableWaveform(content map[string]interface{}) []byte { } var ( - TypeMSC3881PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"} - TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"} + TypeMSC3381PollStart = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"} + TypeMSC3381PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"} + TypeMSC3381V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.v2.poll.response"} ) type PollResponseContent struct { @@ -3532,15 +3546,71 @@ func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) { content.RelatesTo = *rel } -func init() { - event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{}) - event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{}) +type MSC1767Message struct { + Text string `json:"org.matrix.msc1767.text,omitempty"` + HTML string `json:"org.matrix.msc1767.html,omitempty"` + Message []struct { + MimeType string `json:"mimetype"` + Body string `json:"body"` + } `json:"org.matrix.msc1767.message,omitempty"` } -func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) { +func (portal *Portal) msc1767ToWhatsApp(msg MSC1767Message, mentions bool) (string, []string) { + for _, part := range msg.Message { + if part.MimeType == "text/html" && msg.HTML == "" { + msg.HTML = part.Body + } else if part.MimeType == "text/plain" && msg.Text == "" { + msg.Text = part.Body + } + } + if msg.HTML != "" { + if mentions { + return portal.bridge.Formatter.ParseMatrix(msg.HTML) + } else { + return portal.bridge.Formatter.ParseMatrixWithoutMentions(msg.HTML), nil + } + } + return msg.Text, nil +} + +type PollStartContent struct { + RelatesTo *event.RelatesTo `json:"m.relates_to"` + PollStart struct { + Kind string `json:"kind"` + MaxSelections int `json:"max_selections"` + Question MSC1767Message `json:"question"` + Answers []struct { + ID string `json:"id"` + MSC1767Message + } `json:"answers"` + } `json:"org.matrix.msc3381.poll.start"` +} + +func (content *PollStartContent) GetRelatesTo() *event.RelatesTo { + if content.RelatesTo == nil { + content.RelatesTo = &event.RelatesTo{} + } + return content.RelatesTo +} + +func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo { + return content.RelatesTo +} + +func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) { + content.RelatesTo = rel +} + +func init() { + event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{}) + event.TypeMap[TypeMSC3381V2PollResponse] = reflect.TypeOf(PollResponseContent{}) + event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{}) +} + +func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) { content, ok := evt.Content.Parsed.(*PollResponseContent) if !ok { - return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) + return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) } var answers []string if content.V1Response.Answers != nil { @@ -3550,7 +3620,7 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt } pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID) if pollMsg == nil { - return nil, sender, errTargetNotFound + return nil, sender, nil, errTargetNotFound } pollMsgInfo := &types.MessageInfo{ MessageSource: types.MessageSource{ @@ -3563,43 +3633,81 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt Type: "poll", } optionHashes := make([][]byte, 0, len(answers)) - for _, selection := range answers { - hash, _ := hex.DecodeString(selection) - if hash != nil && len(hash) == 32 { - optionHashes = append(optionHashes, hash) + if pollMsg.Type == database.MsgMatrixPoll { + mappedAnswers := pollMsg.GetPollOptionHashes(answers) + for _, selection := range answers { + hash, ok := mappedAnswers[selection] + if ok { + optionHashes = append(optionHashes, hash[:]) + } else { + portal.log.Warnfln("Didn't find hash for option %s in %s's vote to %s", selection, evt.Sender, pollMsg.MXID) + } + } + } else { + for _, selection := range answers { + hash, _ := hex.DecodeString(selection) + if hash != nil && len(hash) == 32 { + optionHashes = append(optionHashes, hash) + } } } pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{ SelectedOptions: optionHashes, }) - return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err + return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, nil, err } -func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) { - if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse { - return portal.convertMatrixPollVote(ctx, sender, evt) - } - content, ok := evt.Content.Parsed.(*event.MessageEventContent) +func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) { + content, ok := evt.Content.Parsed.(*PollStartContent) if !ok { - return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) + return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) } - var editRootMsg *database.Message - if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits { - editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID) - if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User { - return nil, sender, fmt.Errorf("edit rejected") // TODO more specific error message + maxAnswers := content.PollStart.MaxSelections + if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 { + maxAnswers = 0 + } + fmt.Printf("%+v\n", content.PollStart) + ctxInfo := portal.generateContextInfo(content.RelatesTo) + var question string + question, ctxInfo.MentionedJid = portal.msc1767ToWhatsApp(content.PollStart.Question, true) + if len(question) == 0 { + return nil, sender, nil, errPollMissingQuestion + } + options := make([]*waProto.PollCreationMessage_Option, len(content.PollStart.Answers)) + optionMap := make(map[[32]byte]string, len(options)) + for i, opt := range content.PollStart.Answers { + body, _ := portal.msc1767ToWhatsApp(opt.MSC1767Message, false) + hash := sha256.Sum256([]byte(body)) + if _, alreadyExists := optionMap[hash]; alreadyExists { + portal.log.Warnfln("Poll %s by %s has option %q more than once, rejecting", evt.ID, evt.Sender, body) + return nil, sender, nil, errPollDuplicateOption } - if content.NewContent != nil { - content = content.NewContent + optionMap[hash] = opt.ID + options[i] = &waProto.PollCreationMessage_Option{ + OptionName: proto.String(body), } } + secret := make([]byte, 32) + _, err := rand.Read(secret) + return &waProto.Message{ + PollCreationMessage: &waProto.PollCreationMessage{ + Name: proto.String(question), + Options: options, + SelectableOptionsCount: proto.Uint32(uint32(maxAnswers)), + ContextInfo: ctxInfo, + }, + MessageContextInfo: &waProto.MessageContextInfo{ + MessageSecret: secret, + }, + }, sender, &extraConvertMeta{PollOptions: optionMap}, err +} - msg := &waProto.Message{} +func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.ContextInfo { var ctxInfo waProto.ContextInfo - replyToID := content.RelatesTo.GetReplyTo() + replyToID := relatesTo.GetReplyTo() if len(replyToID) > 0 { replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID) - if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal { + if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll) { ctxInfo.StanzaId = &replyToMsg.JID ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String()) // Using blank content here seems to work fine on all official WhatsApp apps. @@ -3613,10 +3721,40 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev if portal.ExpirationTime != 0 { ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime) } + return &ctxInfo +} + +type extraConvertMeta struct { + PollOptions map[[32]byte]string +} + +func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) { + if evt.Type == TypeMSC3381PollResponse || evt.Type == TypeMSC3381V2PollResponse { + return portal.convertMatrixPollVote(ctx, sender, evt) + } else if evt.Type == TypeMSC3381PollStart { + return portal.convertMatrixPollStart(ctx, sender, evt) + } + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) + } + var editRootMsg *database.Message + if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits { + editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID) + if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User { + return nil, sender, nil, fmt.Errorf("edit rejected") // TODO more specific error message + } + if content.NewContent != nil { + content = content.NewContent + } + } + + msg := &waProto.Message{} + ctxInfo := portal.generateContextInfo(content.RelatesTo) relaybotFormatted := false if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) { if !portal.HasRelaybot() { - return nil, sender, errUserNotLoggedIn + return nil, sender, nil, errUserNotLoggedIn } relaybotFormatted = portal.addRelaybotFormat(sender, content) sender = portal.GetRelayUser() @@ -3637,7 +3775,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev case event.MsgText, event.MsgEmote, event.MsgNotice: text := content.Body if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices { - return nil, sender, errMNoticeDisabled + return nil, sender, nil, errMNoticeDisabled } if content.Format == event.FormatHTML { text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody) @@ -3647,11 +3785,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev } msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{ Text: &text, - ContextInfo: &ctxInfo, + ContextInfo: ctxInfo, } hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage) if ctx.Err() != nil { - return nil, nil, ctx.Err() + return nil, nil, nil, ctx.Err() } if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview { // No need for extended message @@ -3661,11 +3799,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev case event.MsgImage: media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage) if media == nil { - return nil, sender, err + return nil, sender, nil, err } ctxInfo.MentionedJid = media.MentionedJIDs msg.ImageMessage = &waProto.ImageMessage{ - ContextInfo: &ctxInfo, + ContextInfo: ctxInfo, Caption: &media.Caption, JpegThumbnail: media.Thumbnail, Url: &media.URL, @@ -3678,11 +3816,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev case event.MessageType(event.EventSticker.Type): media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage) if media == nil { - return nil, sender, err + return nil, sender, nil, err } ctxInfo.MentionedJid = media.MentionedJIDs msg.StickerMessage = &waProto.StickerMessage{ - ContextInfo: &ctxInfo, + ContextInfo: ctxInfo, PngThumbnail: media.Thumbnail, Url: &media.URL, MediaKey: media.MediaKey, @@ -3695,12 +3833,12 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev gifPlayback := content.GetInfo().MimeType == "image/gif" media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo) if media == nil { - return nil, sender, err + return nil, sender, nil, err } duration := uint32(content.GetInfo().Duration / 1000) ctxInfo.MentionedJid = media.MentionedJIDs msg.VideoMessage = &waProto.VideoMessage{ - ContextInfo: &ctxInfo, + ContextInfo: ctxInfo, Caption: &media.Caption, JpegThumbnail: media.Thumbnail, Url: &media.URL, @@ -3715,11 +3853,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev case event.MsgAudio: media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio) if media == nil { - return nil, sender, err + return nil, sender, nil, err } duration := uint32(content.GetInfo().Duration / 1000) msg.AudioMessage = &waProto.AudioMessage{ - ContextInfo: &ctxInfo, + ContextInfo: ctxInfo, Url: &media.URL, MediaKey: media.MediaKey, Mimetype: &content.GetInfo().MimeType, @@ -3738,10 +3876,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev case event.MsgFile: media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument) if media == nil { - return nil, sender, err + return nil, sender, nil, err } msg.DocumentMessage = &waProto.DocumentMessage{ - ContextInfo: &ctxInfo, + ContextInfo: ctxInfo, Caption: &media.Caption, JpegThumbnail: media.Thumbnail, Url: &media.URL, @@ -3764,16 +3902,16 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev case event.MsgLocation: lat, long, err := parseGeoURI(content.GeoURI) if err != nil { - return nil, sender, fmt.Errorf("%w: %v", errInvalidGeoURI, err) + return nil, sender, nil, fmt.Errorf("%w: %v", errInvalidGeoURI, err) } msg.LocationMessage = &waProto.LocationMessage{ DegreesLatitude: &lat, DegreesLongitude: &long, Comment: &content.Body, - ContextInfo: &ctxInfo, + ContextInfo: ctxInfo, } default: - return nil, sender, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType) + return nil, sender, nil, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType) } if editRootMsg != nil { @@ -3795,7 +3933,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev } } - return msg, sender, nil + return msg, sender, nil, nil } func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo { @@ -3815,7 +3953,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing start := time.Now() ms := metricSender{portal: portal, timings: &timings} - allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse + allowRelay := evt.Type != TypeMSC3381PollResponse && evt.Type != TypeMSC3381V2PollResponse && evt.Type != TypeMSC3381PollStart if err := portal.canBridgeFrom(sender, allowRelay); err != nil { go ms.sendMessageMetrics(evt, err, "Ignoring", true) return @@ -3875,14 +4013,16 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing timings.preproc = time.Since(start) start = time.Now() - msg, sender, err := portal.convertMatrixMessage(ctx, sender, evt) + msg, sender, extraMeta, err := portal.convertMatrixMessage(ctx, sender, evt) timings.convert = time.Since(start) if msg == nil { go ms.sendMessageMetrics(evt, err, "Error converting", true) return } dbMsgType := database.MsgNormal - if msg.EditedMessage == nil { + if msg.PollCreationMessage != nil || msg.PollCreationMessageV2 != nil { + dbMsgType = database.MsgMatrixPoll + } else if msg.EditedMessage == nil { portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true) } else { dbMsgType = database.MsgEdit @@ -3893,6 +4033,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing } else { info.ID = dbMsg.JID } + if dbMsgType == database.MsgMatrixPoll && extraMeta != nil && extraMeta.PollOptions != nil { + dbMsg.PutPollOptions(extraMeta.PollOptions) + } portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID) start = time.Now() resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg)