Add support for creating polls from Matrix

This commit is contained in:
Tulir Asokan 2022-12-23 15:17:57 +02:00
parent 52876bb607
commit 0305680317
10 changed files with 374 additions and 73 deletions

View file

@ -1,3 +1,7 @@
# unreleased
* Added support for sending polls from Matrix to WhatsApp.
# v0.8.0 (2022-12-16)
* Added support for bridging polls from WhatsApp and votes in both directions.

View file

@ -6,7 +6,7 @@
* [x] Location messages
* [x] Media/files
* [x] Replies
* [ ] Polls
* [x] Polls
* [x] Poll votes
* [x] Message redactions
* [x] Reactions

View file

@ -134,11 +134,12 @@ const (
type MessageType string
const (
MsgUnknown MessageType = ""
MsgFake MessageType = "fake"
MsgNormal MessageType = "message"
MsgReaction MessageType = "reaction"
MsgEdit MessageType = "edit"
MsgUnknown MessageType = ""
MsgFake MessageType = "fake"
MsgNormal MessageType = "message"
MsgReaction MessageType = "reaction"
MsgEdit MessageType = "edit"
MsgMatrixPoll MessageType = "matrix-poll"
)
type Message struct {

118
database/polloption.go Normal file
View file

@ -0,0 +1,118 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"fmt"
"strings"
"github.com/lib/pq"
"maunium.net/go/mautrix/util/dbutil"
)
func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
var hash []byte
err = rows.Scan(&id, &hash)
if err != nil {
// return below
} else if len(hash) != 32 {
err = fmt.Errorf("unexpected hash length %d", len(hash))
} else {
hashArr = *(*[32]byte)(hash)
}
return
}
func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
args := make([]any, len(opts)*2+1)
placeholders := make([]string, len(opts))
args[0] = msg.MXID
i := 0
for hash, id := range opts {
args[i*2+1] = id
hashCopy := hash
args[i*2+2] = hashCopy[:]
placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
i++
}
query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
_, err := msg.db.Exec(query, args...)
if err != nil {
msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
}
}
func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
var args []any
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(hashes)}
} else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
args = make([]any, len(hashes)+1)
args[0] = msg.MXID
for i, hash := range hashes {
args[i+1] = hash
}
}
ids := make(map[[32]byte]string, len(hashes))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
break
}
ids[hash] = id
}
}
return ids
}
func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
var args []any
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(ids)}
} else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
args = make([]any, len(ids)+1)
args[0] = msg.MXID
for i, id := range ids {
args[i+1] = id
}
}
hashes := make(map[string][32]byte, len(ids))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
break
}
hashes[id] = hash
}
}
return hashes
}

View file

@ -1,4 +1,4 @@
-- v0 -> v52: Latest revision
-- v0 -> v54: Latest revision
CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
@ -40,6 +40,7 @@ CREATE TABLE portal (
PRIMARY KEY (jid, receiver)
);
CREATE INDEX portal_parent_group_idx ON portal(parent_group);
CREATE TABLE puppet (
username TEXT PRIMARY KEY,
@ -79,6 +80,16 @@ CREATE TABLE message (
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
);
CREATE TABLE poll_option_id (
msg_mxid TEXT,
opt_id TEXT,
opt_hash bytea CHECK ( length(opt_hash) = 32 ),
PRIMARY KEY (msg_mxid, opt_id),
CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE reaction (
chat_jid TEXT,
chat_receiver TEXT,

View file

@ -0,0 +1,11 @@
-- v54: Store mapping for poll option IDs from Matrix
CREATE TABLE poll_option_id (
msg_mxid TEXT,
opt_id TEXT,
opt_hash bytea CHECK ( length(opt_hash) = 32 ),
PRIMARY KEY (msg_mxid, opt_id),
CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
);

View file

@ -35,7 +35,8 @@ var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)
var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`)
const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
const mentionedJIDsContextKey = "fi.mau.whatsapp.mentioned_jids"
const disableMentionsContextKey = "fi.mau.whatsapp.no_mentions"
type Formatter struct {
bridge *WABridge
@ -55,7 +56,8 @@ func NewFormatter(bridge *WABridge) *Formatter {
Newline: "\n",
PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
if mxid[0] == '@' {
_, disableMentions := ctx[disableMentionsContextKey]
if mxid[0] == '@' && !disableMentions {
puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
if puppet != nil {
jids, ok := ctx[mentionedJIDsContextKey].([]string)
@ -67,7 +69,7 @@ func NewFormatter(bridge *WABridge) *Formatter {
return "@" + puppet.JID.User
}
}
return mxid
return displayname
},
BoldConverter: func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) },
ItalicConverter: func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) },
@ -151,3 +153,9 @@ func (formatter *Formatter) ParseMatrix(html string) (string, []string) {
mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string)
return result, mentionedJIDs
}
func (formatter *Formatter) ParseMatrixWithoutMentions(html string) string {
ctx := make(format.Context)
ctx[disableMentionsContextKey] = true
return formatter.matrixHTMLParser.Parse(html, ctx)
}

View file

@ -87,8 +87,9 @@ func (br *WABridge) Init() {
// TODO this is a weird place for this
br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence)
br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3381PollStart, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
Segment.log = br.Log.Sub("Segment")
Segment.key = br.Config.SegmentKey

View file

@ -52,6 +52,8 @@ var (
errTargetIsFake = errors.New("target is a fake event")
errReactionSentBySomeoneElse = errors.New("target reaction was sent by someone else")
errDMSentByOtherUser = errors.New("target message was sent by the other user in a DM")
errPollMissingQuestion = errors.New("poll message is missing question")
errPollDuplicateOption = errors.New("poll options must be unique")
errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported")
errBroadcastSendDisabled = errors.New("sending status messages is disabled")
@ -76,7 +78,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
case errors.Is(err, errMNoticeDisabled):
return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, ""
case errors.Is(err, errMediaUnsupportedType):
case errors.Is(err, errMediaUnsupportedType), errors.Is(err, errPollMissingQuestion), errors.Is(err, errPollDuplicateOption):
return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error()
case errors.Is(err, errTimeoutBeforeHandling):
return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
@ -185,8 +187,10 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
msgType = "reaction"
case event.EventRedaction:
msgType = "redaction"
case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse:
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default:
msgType = "unknown event"
}

261
portal.go
View file

@ -19,6 +19,7 @@ package main
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
@ -322,7 +323,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) {
portal.handleMatrixReadReceipt(msg.user, "", evtTS, false)
timings.implicitRR = time.Since(implicitRRStart)
switch msg.evt.Type {
case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse:
case event.EventMessage, event.EventSticker, TypeMSC3381V2PollResponse, TypeMSC3381PollResponse, TypeMSC3381PollStart:
portal.HandleMatrixMessage(msg.user, msg.evt, timings)
case event.EventRedaction:
portal.HandleMatrixRedaction(msg.user, msg.evt)
@ -2267,9 +2268,6 @@ func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, m
}
func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage {
if !portal.bridge.Config.Bridge.ExtEvPolls {
return nil
}
pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId())
if pollMessage == nil {
portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId())
@ -2284,13 +2282,28 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou
return nil
}
selectedHashes := make([]string, len(vote.GetSelectedOptions()))
for i, opt := range vote.GetSelectedOptions() {
selectedHashes[i] = hex.EncodeToString(opt)
if pollMessage.Type == database.MsgMatrixPoll {
mappedAnswers := pollMessage.GetPollOptionIDs(vote.GetSelectedOptions())
for i, opt := range vote.GetSelectedOptions() {
if len(opt) != 32 {
portal.log.Warnfln("Unexpected option hash length %d in %s's vote to %s", len(opt), info.Sender, pollMessage.MXID)
continue
}
var ok bool
selectedHashes[i], ok = mappedAnswers[*(*[32]byte)(opt)]
if !ok {
portal.log.Warnfln("Didn't find ID for option %X in %s's vote to %s", opt, info.Sender, pollMessage.MXID)
}
}
} else {
for i, opt := range vote.GetSelectedOptions() {
selectedHashes[i] = hex.EncodeToString(opt)
}
}
evtType := TypeMSC3881PollResponse
evtType := TypeMSC3381PollResponse
//if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
// evtType = TypeMSC3881V2PollResponse
// evtType = TypeMSC3381V2PollResponse
//}
return &ConvertedMessage{
Intent: intent,
@ -2341,7 +2354,7 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m
}
evtType := event.EventMessage
if portal.bridge.Config.Bridge.ExtEvPolls {
evtType.Type = "org.matrix.msc3381.poll.start"
evtType = TypeMSC3381PollStart
}
//else if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
// evtType.Type = "org.matrix.msc3381.v2.poll.start"
@ -3505,8 +3518,9 @@ func getUnstableWaveform(content map[string]interface{}) []byte {
}
var (
TypeMSC3881PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"}
TypeMSC3381PollStart = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"}
TypeMSC3381PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
TypeMSC3381V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.v2.poll.response"}
)
type PollResponseContent struct {
@ -3532,15 +3546,71 @@ func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) {
content.RelatesTo = *rel
}
func init() {
event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{})
event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{})
type MSC1767Message struct {
Text string `json:"org.matrix.msc1767.text,omitempty"`
HTML string `json:"org.matrix.msc1767.html,omitempty"`
Message []struct {
MimeType string `json:"mimetype"`
Body string `json:"body"`
} `json:"org.matrix.msc1767.message,omitempty"`
}
func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
func (portal *Portal) msc1767ToWhatsApp(msg MSC1767Message, mentions bool) (string, []string) {
for _, part := range msg.Message {
if part.MimeType == "text/html" && msg.HTML == "" {
msg.HTML = part.Body
} else if part.MimeType == "text/plain" && msg.Text == "" {
msg.Text = part.Body
}
}
if msg.HTML != "" {
if mentions {
return portal.bridge.Formatter.ParseMatrix(msg.HTML)
} else {
return portal.bridge.Formatter.ParseMatrixWithoutMentions(msg.HTML), nil
}
}
return msg.Text, nil
}
type PollStartContent struct {
RelatesTo *event.RelatesTo `json:"m.relates_to"`
PollStart struct {
Kind string `json:"kind"`
MaxSelections int `json:"max_selections"`
Question MSC1767Message `json:"question"`
Answers []struct {
ID string `json:"id"`
MSC1767Message
} `json:"answers"`
} `json:"org.matrix.msc3381.poll.start"`
}
func (content *PollStartContent) GetRelatesTo() *event.RelatesTo {
if content.RelatesTo == nil {
content.RelatesTo = &event.RelatesTo{}
}
return content.RelatesTo
}
func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo {
return content.RelatesTo
}
func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) {
content.RelatesTo = rel
}
func init() {
event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{})
event.TypeMap[TypeMSC3381V2PollResponse] = reflect.TypeOf(PollResponseContent{})
event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{})
}
func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
content, ok := evt.Content.Parsed.(*PollResponseContent)
if !ok {
return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
}
var answers []string
if content.V1Response.Answers != nil {
@ -3550,7 +3620,7 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
}
pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID)
if pollMsg == nil {
return nil, sender, errTargetNotFound
return nil, sender, nil, errTargetNotFound
}
pollMsgInfo := &types.MessageInfo{
MessageSource: types.MessageSource{
@ -3563,43 +3633,81 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
Type: "poll",
}
optionHashes := make([][]byte, 0, len(answers))
for _, selection := range answers {
hash, _ := hex.DecodeString(selection)
if hash != nil && len(hash) == 32 {
optionHashes = append(optionHashes, hash)
if pollMsg.Type == database.MsgMatrixPoll {
mappedAnswers := pollMsg.GetPollOptionHashes(answers)
for _, selection := range answers {
hash, ok := mappedAnswers[selection]
if ok {
optionHashes = append(optionHashes, hash[:])
} else {
portal.log.Warnfln("Didn't find hash for option %s in %s's vote to %s", selection, evt.Sender, pollMsg.MXID)
}
}
} else {
for _, selection := range answers {
hash, _ := hex.DecodeString(selection)
if hash != nil && len(hash) == 32 {
optionHashes = append(optionHashes, hash)
}
}
}
pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{
SelectedOptions: optionHashes,
})
return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err
return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, nil, err
}
func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse {
return portal.convertMatrixPollVote(ctx, sender, evt)
}
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
content, ok := evt.Content.Parsed.(*PollStartContent)
if !ok {
return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
}
var editRootMsg *database.Message
if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
return nil, sender, fmt.Errorf("edit rejected") // TODO more specific error message
maxAnswers := content.PollStart.MaxSelections
if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 {
maxAnswers = 0
}
fmt.Printf("%+v\n", content.PollStart)
ctxInfo := portal.generateContextInfo(content.RelatesTo)
var question string
question, ctxInfo.MentionedJid = portal.msc1767ToWhatsApp(content.PollStart.Question, true)
if len(question) == 0 {
return nil, sender, nil, errPollMissingQuestion
}
options := make([]*waProto.PollCreationMessage_Option, len(content.PollStart.Answers))
optionMap := make(map[[32]byte]string, len(options))
for i, opt := range content.PollStart.Answers {
body, _ := portal.msc1767ToWhatsApp(opt.MSC1767Message, false)
hash := sha256.Sum256([]byte(body))
if _, alreadyExists := optionMap[hash]; alreadyExists {
portal.log.Warnfln("Poll %s by %s has option %q more than once, rejecting", evt.ID, evt.Sender, body)
return nil, sender, nil, errPollDuplicateOption
}
if content.NewContent != nil {
content = content.NewContent
optionMap[hash] = opt.ID
options[i] = &waProto.PollCreationMessage_Option{
OptionName: proto.String(body),
}
}
secret := make([]byte, 32)
_, err := rand.Read(secret)
return &waProto.Message{
PollCreationMessage: &waProto.PollCreationMessage{
Name: proto.String(question),
Options: options,
SelectableOptionsCount: proto.Uint32(uint32(maxAnswers)),
ContextInfo: ctxInfo,
},
MessageContextInfo: &waProto.MessageContextInfo{
MessageSecret: secret,
},
}, sender, &extraConvertMeta{PollOptions: optionMap}, err
}
msg := &waProto.Message{}
func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.ContextInfo {
var ctxInfo waProto.ContextInfo
replyToID := content.RelatesTo.GetReplyTo()
replyToID := relatesTo.GetReplyTo()
if len(replyToID) > 0 {
replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID)
if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal {
if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll) {
ctxInfo.StanzaId = &replyToMsg.JID
ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String())
// Using blank content here seems to work fine on all official WhatsApp apps.
@ -3613,10 +3721,40 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
if portal.ExpirationTime != 0 {
ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime)
}
return &ctxInfo
}
type extraConvertMeta struct {
PollOptions map[[32]byte]string
}
func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
if evt.Type == TypeMSC3381PollResponse || evt.Type == TypeMSC3381V2PollResponse {
return portal.convertMatrixPollVote(ctx, sender, evt)
} else if evt.Type == TypeMSC3381PollStart {
return portal.convertMatrixPollStart(ctx, sender, evt)
}
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok {
return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
}
var editRootMsg *database.Message
if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
return nil, sender, nil, fmt.Errorf("edit rejected") // TODO more specific error message
}
if content.NewContent != nil {
content = content.NewContent
}
}
msg := &waProto.Message{}
ctxInfo := portal.generateContextInfo(content.RelatesTo)
relaybotFormatted := false
if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) {
if !portal.HasRelaybot() {
return nil, sender, errUserNotLoggedIn
return nil, sender, nil, errUserNotLoggedIn
}
relaybotFormatted = portal.addRelaybotFormat(sender, content)
sender = portal.GetRelayUser()
@ -3637,7 +3775,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgText, event.MsgEmote, event.MsgNotice:
text := content.Body
if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices {
return nil, sender, errMNoticeDisabled
return nil, sender, nil, errMNoticeDisabled
}
if content.Format == event.FormatHTML {
text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
@ -3647,11 +3785,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
}
msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{
Text: &text,
ContextInfo: &ctxInfo,
ContextInfo: ctxInfo,
}
hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage)
if ctx.Err() != nil {
return nil, nil, ctx.Err()
return nil, nil, nil, ctx.Err()
}
if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview {
// No need for extended message
@ -3661,11 +3799,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgImage:
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
if media == nil {
return nil, sender, err
return nil, sender, nil, err
}
ctxInfo.MentionedJid = media.MentionedJIDs
msg.ImageMessage = &waProto.ImageMessage{
ContextInfo: &ctxInfo,
ContextInfo: ctxInfo,
Caption: &media.Caption,
JpegThumbnail: media.Thumbnail,
Url: &media.URL,
@ -3678,11 +3816,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MessageType(event.EventSticker.Type):
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
if media == nil {
return nil, sender, err
return nil, sender, nil, err
}
ctxInfo.MentionedJid = media.MentionedJIDs
msg.StickerMessage = &waProto.StickerMessage{
ContextInfo: &ctxInfo,
ContextInfo: ctxInfo,
PngThumbnail: media.Thumbnail,
Url: &media.URL,
MediaKey: media.MediaKey,
@ -3695,12 +3833,12 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
gifPlayback := content.GetInfo().MimeType == "image/gif"
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)
if media == nil {
return nil, sender, err
return nil, sender, nil, err
}
duration := uint32(content.GetInfo().Duration / 1000)
ctxInfo.MentionedJid = media.MentionedJIDs
msg.VideoMessage = &waProto.VideoMessage{
ContextInfo: &ctxInfo,
ContextInfo: ctxInfo,
Caption: &media.Caption,
JpegThumbnail: media.Thumbnail,
Url: &media.URL,
@ -3715,11 +3853,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgAudio:
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio)
if media == nil {
return nil, sender, err
return nil, sender, nil, err
}
duration := uint32(content.GetInfo().Duration / 1000)
msg.AudioMessage = &waProto.AudioMessage{
ContextInfo: &ctxInfo,
ContextInfo: ctxInfo,
Url: &media.URL,
MediaKey: media.MediaKey,
Mimetype: &content.GetInfo().MimeType,
@ -3738,10 +3876,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgFile:
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument)
if media == nil {
return nil, sender, err
return nil, sender, nil, err
}
msg.DocumentMessage = &waProto.DocumentMessage{
ContextInfo: &ctxInfo,
ContextInfo: ctxInfo,
Caption: &media.Caption,
JpegThumbnail: media.Thumbnail,
Url: &media.URL,
@ -3764,16 +3902,16 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
case event.MsgLocation:
lat, long, err := parseGeoURI(content.GeoURI)
if err != nil {
return nil, sender, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
return nil, sender, nil, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
}
msg.LocationMessage = &waProto.LocationMessage{
DegreesLatitude: &lat,
DegreesLongitude: &long,
Comment: &content.Body,
ContextInfo: &ctxInfo,
ContextInfo: ctxInfo,
}
default:
return nil, sender, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
return nil, sender, nil, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
}
if editRootMsg != nil {
@ -3795,7 +3933,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
}
}
return msg, sender, nil
return msg, sender, nil, nil
}
func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo {
@ -3815,7 +3953,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
start := time.Now()
ms := metricSender{portal: portal, timings: &timings}
allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse
allowRelay := evt.Type != TypeMSC3381PollResponse && evt.Type != TypeMSC3381V2PollResponse && evt.Type != TypeMSC3381PollStart
if err := portal.canBridgeFrom(sender, allowRelay); err != nil {
go ms.sendMessageMetrics(evt, err, "Ignoring", true)
return
@ -3875,14 +4013,16 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
timings.preproc = time.Since(start)
start = time.Now()
msg, sender, err := portal.convertMatrixMessage(ctx, sender, evt)
msg, sender, extraMeta, err := portal.convertMatrixMessage(ctx, sender, evt)
timings.convert = time.Since(start)
if msg == nil {
go ms.sendMessageMetrics(evt, err, "Error converting", true)
return
}
dbMsgType := database.MsgNormal
if msg.EditedMessage == nil {
if msg.PollCreationMessage != nil || msg.PollCreationMessageV2 != nil {
dbMsgType = database.MsgMatrixPoll
} else if msg.EditedMessage == nil {
portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true)
} else {
dbMsgType = database.MsgEdit
@ -3893,6 +4033,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
} else {
info.ID = dbMsg.JID
}
if dbMsgType == database.MsgMatrixPoll && extraMeta != nil && extraMeta.PollOptions != nil {
dbMsg.PutPollOptions(extraMeta.PollOptions)
}
portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
start = time.Now()
resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg)