forked from MirrorHub/mautrix-whatsapp
Add support for creating polls from Matrix
This commit is contained in:
parent
52876bb607
commit
0305680317
10 changed files with 374 additions and 73 deletions
|
@ -1,3 +1,7 @@
|
|||
# unreleased
|
||||
|
||||
* Added support for sending polls from Matrix to WhatsApp.
|
||||
|
||||
# v0.8.0 (2022-12-16)
|
||||
|
||||
* Added support for bridging polls from WhatsApp and votes in both directions.
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
* [x] Location messages
|
||||
* [x] Media/files
|
||||
* [x] Replies
|
||||
* [ ] Polls
|
||||
* [x] Polls
|
||||
* [x] Poll votes
|
||||
* [x] Message redactions
|
||||
* [x] Reactions
|
||||
|
|
|
@ -134,11 +134,12 @@ const (
|
|||
type MessageType string
|
||||
|
||||
const (
|
||||
MsgUnknown MessageType = ""
|
||||
MsgFake MessageType = "fake"
|
||||
MsgNormal MessageType = "message"
|
||||
MsgReaction MessageType = "reaction"
|
||||
MsgEdit MessageType = "edit"
|
||||
MsgUnknown MessageType = ""
|
||||
MsgFake MessageType = "fake"
|
||||
MsgNormal MessageType = "message"
|
||||
MsgReaction MessageType = "reaction"
|
||||
MsgEdit MessageType = "edit"
|
||||
MsgMatrixPoll MessageType = "matrix-poll"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
|
|
118
database/polloption.go
Normal file
118
database/polloption.go
Normal file
|
@ -0,0 +1,118 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
|
||||
var hash []byte
|
||||
err = rows.Scan(&id, &hash)
|
||||
if err != nil {
|
||||
// return below
|
||||
} else if len(hash) != 32 {
|
||||
err = fmt.Errorf("unexpected hash length %d", len(hash))
|
||||
} else {
|
||||
hashArr = *(*[32]byte)(hash)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
|
||||
query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
|
||||
args := make([]any, len(opts)*2+1)
|
||||
placeholders := make([]string, len(opts))
|
||||
args[0] = msg.MXID
|
||||
i := 0
|
||||
for hash, id := range opts {
|
||||
args[i*2+1] = id
|
||||
hashCopy := hash
|
||||
args[i*2+2] = hashCopy[:]
|
||||
placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
|
||||
i++
|
||||
}
|
||||
query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
|
||||
_, err := msg.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
|
||||
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
|
||||
var args []any
|
||||
if msg.db.Dialect == dbutil.Postgres {
|
||||
args = []any{msg.MXID, pq.Array(hashes)}
|
||||
} else {
|
||||
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
|
||||
args = make([]any, len(hashes)+1)
|
||||
args[0] = msg.MXID
|
||||
for i, hash := range hashes {
|
||||
args[i+1] = hash
|
||||
}
|
||||
}
|
||||
ids := make(map[[32]byte]string, len(hashes))
|
||||
rows, err := msg.db.Query(query, args...)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
|
||||
} else {
|
||||
for rows.Next() {
|
||||
id, hash, err := scanPollOptionMapping(rows)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
|
||||
break
|
||||
}
|
||||
ids[hash] = id
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
|
||||
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
|
||||
var args []any
|
||||
if msg.db.Dialect == dbutil.Postgres {
|
||||
args = []any{msg.MXID, pq.Array(ids)}
|
||||
} else {
|
||||
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
|
||||
args = make([]any, len(ids)+1)
|
||||
args[0] = msg.MXID
|
||||
for i, id := range ids {
|
||||
args[i+1] = id
|
||||
}
|
||||
}
|
||||
hashes := make(map[string][32]byte, len(ids))
|
||||
rows, err := msg.db.Query(query, args...)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
|
||||
} else {
|
||||
for rows.Next() {
|
||||
id, hash, err := scanPollOptionMapping(rows)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
|
||||
break
|
||||
}
|
||||
hashes[id] = hash
|
||||
}
|
||||
}
|
||||
return hashes
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
-- v0 -> v52: Latest revision
|
||||
-- v0 -> v54: Latest revision
|
||||
|
||||
CREATE TABLE "user" (
|
||||
mxid TEXT PRIMARY KEY,
|
||||
|
@ -40,6 +40,7 @@ CREATE TABLE portal (
|
|||
|
||||
PRIMARY KEY (jid, receiver)
|
||||
);
|
||||
CREATE INDEX portal_parent_group_idx ON portal(parent_group);
|
||||
|
||||
CREATE TABLE puppet (
|
||||
username TEXT PRIMARY KEY,
|
||||
|
@ -79,6 +80,16 @@ CREATE TABLE message (
|
|||
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE poll_option_id (
|
||||
msg_mxid TEXT,
|
||||
opt_id TEXT,
|
||||
opt_hash bytea CHECK ( length(opt_hash) = 32 ),
|
||||
|
||||
PRIMARY KEY (msg_mxid, opt_id),
|
||||
CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
|
||||
CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE reaction (
|
||||
chat_jid TEXT,
|
||||
chat_receiver TEXT,
|
||||
|
|
11
database/upgrades/54-poll-option-id-map.sql
Normal file
11
database/upgrades/54-poll-option-id-map.sql
Normal file
|
@ -0,0 +1,11 @@
|
|||
-- v54: Store mapping for poll option IDs from Matrix
|
||||
|
||||
CREATE TABLE poll_option_id (
|
||||
msg_mxid TEXT,
|
||||
opt_id TEXT,
|
||||
opt_hash bytea CHECK ( length(opt_hash) = 32 ),
|
||||
|
||||
PRIMARY KEY (msg_mxid, opt_id),
|
||||
CONSTRAINT poll_option_unique_hash UNIQUE (msg_mxid, opt_hash),
|
||||
CONSTRAINT message_mxid_fkey FOREIGN KEY (msg_mxid) REFERENCES message(mxid) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
|
@ -35,7 +35,8 @@ var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)
|
|||
var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
|
||||
var inlineURLRegex = regexp.MustCompile(`\[(.+?)]\((.+?)\)`)
|
||||
|
||||
const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
|
||||
const mentionedJIDsContextKey = "fi.mau.whatsapp.mentioned_jids"
|
||||
const disableMentionsContextKey = "fi.mau.whatsapp.no_mentions"
|
||||
|
||||
type Formatter struct {
|
||||
bridge *WABridge
|
||||
|
@ -55,7 +56,8 @@ func NewFormatter(bridge *WABridge) *Formatter {
|
|||
Newline: "\n",
|
||||
|
||||
PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
|
||||
if mxid[0] == '@' {
|
||||
_, disableMentions := ctx[disableMentionsContextKey]
|
||||
if mxid[0] == '@' && !disableMentions {
|
||||
puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
|
||||
if puppet != nil {
|
||||
jids, ok := ctx[mentionedJIDsContextKey].([]string)
|
||||
|
@ -67,7 +69,7 @@ func NewFormatter(bridge *WABridge) *Formatter {
|
|||
return "@" + puppet.JID.User
|
||||
}
|
||||
}
|
||||
return mxid
|
||||
return displayname
|
||||
},
|
||||
BoldConverter: func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) },
|
||||
ItalicConverter: func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) },
|
||||
|
@ -151,3 +153,9 @@ func (formatter *Formatter) ParseMatrix(html string) (string, []string) {
|
|||
mentionedJIDs, _ := ctx[mentionedJIDsContextKey].([]string)
|
||||
return result, mentionedJIDs
|
||||
}
|
||||
|
||||
func (formatter *Formatter) ParseMatrixWithoutMentions(html string) string {
|
||||
ctx := make(format.Context)
|
||||
ctx[disableMentionsContextKey] = true
|
||||
return formatter.matrixHTMLParser.Parse(html, ctx)
|
||||
}
|
||||
|
|
5
main.go
5
main.go
|
@ -87,8 +87,9 @@ func (br *WABridge) Init() {
|
|||
|
||||
// TODO this is a weird place for this
|
||||
br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence)
|
||||
br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage)
|
||||
br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage)
|
||||
br.EventProcessor.On(TypeMSC3381PollStart, br.MatrixHandler.HandleMessage)
|
||||
br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
|
||||
br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
|
||||
|
||||
Segment.log = br.Log.Sub("Segment")
|
||||
Segment.key = br.Config.SegmentKey
|
||||
|
|
|
@ -52,6 +52,8 @@ var (
|
|||
errTargetIsFake = errors.New("target is a fake event")
|
||||
errReactionSentBySomeoneElse = errors.New("target reaction was sent by someone else")
|
||||
errDMSentByOtherUser = errors.New("target message was sent by the other user in a DM")
|
||||
errPollMissingQuestion = errors.New("poll message is missing question")
|
||||
errPollDuplicateOption = errors.New("poll options must be unique")
|
||||
|
||||
errBroadcastReactionNotSupported = errors.New("reacting to status messages is not currently supported")
|
||||
errBroadcastSendDisabled = errors.New("sending status messages is disabled")
|
||||
|
@ -76,7 +78,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
|
|||
return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, ""
|
||||
case errors.Is(err, errMNoticeDisabled):
|
||||
return event.MessageStatusUnsupported, event.MessageStatusFail, true, false, ""
|
||||
case errors.Is(err, errMediaUnsupportedType):
|
||||
case errors.Is(err, errMediaUnsupportedType), errors.Is(err, errPollMissingQuestion), errors.Is(err, errPollDuplicateOption):
|
||||
return event.MessageStatusUnsupported, event.MessageStatusFail, true, true, err.Error()
|
||||
case errors.Is(err, errTimeoutBeforeHandling):
|
||||
return event.MessageStatusTooOld, event.MessageStatusRetriable, true, true, "the message was too old when it reached the bridge, so it was not handled"
|
||||
|
@ -185,8 +187,10 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
|
|||
msgType = "reaction"
|
||||
case event.EventRedaction:
|
||||
msgType = "redaction"
|
||||
case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse:
|
||||
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
|
||||
msgType = "poll response"
|
||||
case TypeMSC3381PollStart:
|
||||
msgType = "poll start"
|
||||
default:
|
||||
msgType = "unknown event"
|
||||
}
|
||||
|
|
261
portal.go
261
portal.go
|
@ -19,6 +19,7 @@ package main
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
@ -322,7 +323,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) {
|
|||
portal.handleMatrixReadReceipt(msg.user, "", evtTS, false)
|
||||
timings.implicitRR = time.Since(implicitRRStart)
|
||||
switch msg.evt.Type {
|
||||
case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse:
|
||||
case event.EventMessage, event.EventSticker, TypeMSC3381V2PollResponse, TypeMSC3381PollResponse, TypeMSC3381PollStart:
|
||||
portal.HandleMatrixMessage(msg.user, msg.evt, timings)
|
||||
case event.EventRedaction:
|
||||
portal.HandleMatrixRedaction(msg.user, msg.evt)
|
||||
|
@ -2267,9 +2268,6 @@ func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, m
|
|||
}
|
||||
|
||||
func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage {
|
||||
if !portal.bridge.Config.Bridge.ExtEvPolls {
|
||||
return nil
|
||||
}
|
||||
pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId())
|
||||
if pollMessage == nil {
|
||||
portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId())
|
||||
|
@ -2284,13 +2282,28 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou
|
|||
return nil
|
||||
}
|
||||
selectedHashes := make([]string, len(vote.GetSelectedOptions()))
|
||||
for i, opt := range vote.GetSelectedOptions() {
|
||||
selectedHashes[i] = hex.EncodeToString(opt)
|
||||
if pollMessage.Type == database.MsgMatrixPoll {
|
||||
mappedAnswers := pollMessage.GetPollOptionIDs(vote.GetSelectedOptions())
|
||||
for i, opt := range vote.GetSelectedOptions() {
|
||||
if len(opt) != 32 {
|
||||
portal.log.Warnfln("Unexpected option hash length %d in %s's vote to %s", len(opt), info.Sender, pollMessage.MXID)
|
||||
continue
|
||||
}
|
||||
var ok bool
|
||||
selectedHashes[i], ok = mappedAnswers[*(*[32]byte)(opt)]
|
||||
if !ok {
|
||||
portal.log.Warnfln("Didn't find ID for option %X in %s's vote to %s", opt, info.Sender, pollMessage.MXID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, opt := range vote.GetSelectedOptions() {
|
||||
selectedHashes[i] = hex.EncodeToString(opt)
|
||||
}
|
||||
}
|
||||
|
||||
evtType := TypeMSC3881PollResponse
|
||||
evtType := TypeMSC3381PollResponse
|
||||
//if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
|
||||
// evtType = TypeMSC3881V2PollResponse
|
||||
// evtType = TypeMSC3381V2PollResponse
|
||||
//}
|
||||
return &ConvertedMessage{
|
||||
Intent: intent,
|
||||
|
@ -2341,7 +2354,7 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m
|
|||
}
|
||||
evtType := event.EventMessage
|
||||
if portal.bridge.Config.Bridge.ExtEvPolls {
|
||||
evtType.Type = "org.matrix.msc3381.poll.start"
|
||||
evtType = TypeMSC3381PollStart
|
||||
}
|
||||
//else if portal.bridge.Config.Bridge.ExtEvPolls == 2 {
|
||||
// evtType.Type = "org.matrix.msc3381.v2.poll.start"
|
||||
|
@ -3505,8 +3518,9 @@ func getUnstableWaveform(content map[string]interface{}) []byte {
|
|||
}
|
||||
|
||||
var (
|
||||
TypeMSC3881PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
|
||||
TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"}
|
||||
TypeMSC3381PollStart = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.start"}
|
||||
TypeMSC3381PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"}
|
||||
TypeMSC3381V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.v2.poll.response"}
|
||||
)
|
||||
|
||||
type PollResponseContent struct {
|
||||
|
@ -3532,15 +3546,71 @@ func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) {
|
|||
content.RelatesTo = *rel
|
||||
}
|
||||
|
||||
func init() {
|
||||
event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{})
|
||||
event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{})
|
||||
type MSC1767Message struct {
|
||||
Text string `json:"org.matrix.msc1767.text,omitempty"`
|
||||
HTML string `json:"org.matrix.msc1767.html,omitempty"`
|
||||
Message []struct {
|
||||
MimeType string `json:"mimetype"`
|
||||
Body string `json:"body"`
|
||||
} `json:"org.matrix.msc1767.message,omitempty"`
|
||||
}
|
||||
|
||||
func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
|
||||
func (portal *Portal) msc1767ToWhatsApp(msg MSC1767Message, mentions bool) (string, []string) {
|
||||
for _, part := range msg.Message {
|
||||
if part.MimeType == "text/html" && msg.HTML == "" {
|
||||
msg.HTML = part.Body
|
||||
} else if part.MimeType == "text/plain" && msg.Text == "" {
|
||||
msg.Text = part.Body
|
||||
}
|
||||
}
|
||||
if msg.HTML != "" {
|
||||
if mentions {
|
||||
return portal.bridge.Formatter.ParseMatrix(msg.HTML)
|
||||
} else {
|
||||
return portal.bridge.Formatter.ParseMatrixWithoutMentions(msg.HTML), nil
|
||||
}
|
||||
}
|
||||
return msg.Text, nil
|
||||
}
|
||||
|
||||
type PollStartContent struct {
|
||||
RelatesTo *event.RelatesTo `json:"m.relates_to"`
|
||||
PollStart struct {
|
||||
Kind string `json:"kind"`
|
||||
MaxSelections int `json:"max_selections"`
|
||||
Question MSC1767Message `json:"question"`
|
||||
Answers []struct {
|
||||
ID string `json:"id"`
|
||||
MSC1767Message
|
||||
} `json:"answers"`
|
||||
} `json:"org.matrix.msc3381.poll.start"`
|
||||
}
|
||||
|
||||
func (content *PollStartContent) GetRelatesTo() *event.RelatesTo {
|
||||
if content.RelatesTo == nil {
|
||||
content.RelatesTo = &event.RelatesTo{}
|
||||
}
|
||||
return content.RelatesTo
|
||||
}
|
||||
|
||||
func (content *PollStartContent) OptionalGetRelatesTo() *event.RelatesTo {
|
||||
return content.RelatesTo
|
||||
}
|
||||
|
||||
func (content *PollStartContent) SetRelatesTo(rel *event.RelatesTo) {
|
||||
content.RelatesTo = rel
|
||||
}
|
||||
|
||||
func init() {
|
||||
event.TypeMap[TypeMSC3381PollResponse] = reflect.TypeOf(PollResponseContent{})
|
||||
event.TypeMap[TypeMSC3381V2PollResponse] = reflect.TypeOf(PollResponseContent{})
|
||||
event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{})
|
||||
}
|
||||
|
||||
func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
|
||||
content, ok := evt.Content.Parsed.(*PollResponseContent)
|
||||
if !ok {
|
||||
return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
|
||||
return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
|
||||
}
|
||||
var answers []string
|
||||
if content.V1Response.Answers != nil {
|
||||
|
@ -3550,7 +3620,7 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
|
|||
}
|
||||
pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID)
|
||||
if pollMsg == nil {
|
||||
return nil, sender, errTargetNotFound
|
||||
return nil, sender, nil, errTargetNotFound
|
||||
}
|
||||
pollMsgInfo := &types.MessageInfo{
|
||||
MessageSource: types.MessageSource{
|
||||
|
@ -3563,43 +3633,81 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt
|
|||
Type: "poll",
|
||||
}
|
||||
optionHashes := make([][]byte, 0, len(answers))
|
||||
for _, selection := range answers {
|
||||
hash, _ := hex.DecodeString(selection)
|
||||
if hash != nil && len(hash) == 32 {
|
||||
optionHashes = append(optionHashes, hash)
|
||||
if pollMsg.Type == database.MsgMatrixPoll {
|
||||
mappedAnswers := pollMsg.GetPollOptionHashes(answers)
|
||||
for _, selection := range answers {
|
||||
hash, ok := mappedAnswers[selection]
|
||||
if ok {
|
||||
optionHashes = append(optionHashes, hash[:])
|
||||
} else {
|
||||
portal.log.Warnfln("Didn't find hash for option %s in %s's vote to %s", selection, evt.Sender, pollMsg.MXID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, selection := range answers {
|
||||
hash, _ := hex.DecodeString(selection)
|
||||
if hash != nil && len(hash) == 32 {
|
||||
optionHashes = append(optionHashes, hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{
|
||||
SelectedOptions: optionHashes,
|
||||
})
|
||||
return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err
|
||||
return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, nil, err
|
||||
}
|
||||
|
||||
func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) {
|
||||
if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse {
|
||||
return portal.convertMatrixPollVote(ctx, sender, evt)
|
||||
}
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
|
||||
content, ok := evt.Content.Parsed.(*PollStartContent)
|
||||
if !ok {
|
||||
return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
|
||||
return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
|
||||
}
|
||||
var editRootMsg *database.Message
|
||||
if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
|
||||
editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
|
||||
if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
|
||||
return nil, sender, fmt.Errorf("edit rejected") // TODO more specific error message
|
||||
maxAnswers := content.PollStart.MaxSelections
|
||||
if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 {
|
||||
maxAnswers = 0
|
||||
}
|
||||
fmt.Printf("%+v\n", content.PollStart)
|
||||
ctxInfo := portal.generateContextInfo(content.RelatesTo)
|
||||
var question string
|
||||
question, ctxInfo.MentionedJid = portal.msc1767ToWhatsApp(content.PollStart.Question, true)
|
||||
if len(question) == 0 {
|
||||
return nil, sender, nil, errPollMissingQuestion
|
||||
}
|
||||
options := make([]*waProto.PollCreationMessage_Option, len(content.PollStart.Answers))
|
||||
optionMap := make(map[[32]byte]string, len(options))
|
||||
for i, opt := range content.PollStart.Answers {
|
||||
body, _ := portal.msc1767ToWhatsApp(opt.MSC1767Message, false)
|
||||
hash := sha256.Sum256([]byte(body))
|
||||
if _, alreadyExists := optionMap[hash]; alreadyExists {
|
||||
portal.log.Warnfln("Poll %s by %s has option %q more than once, rejecting", evt.ID, evt.Sender, body)
|
||||
return nil, sender, nil, errPollDuplicateOption
|
||||
}
|
||||
if content.NewContent != nil {
|
||||
content = content.NewContent
|
||||
optionMap[hash] = opt.ID
|
||||
options[i] = &waProto.PollCreationMessage_Option{
|
||||
OptionName: proto.String(body),
|
||||
}
|
||||
}
|
||||
secret := make([]byte, 32)
|
||||
_, err := rand.Read(secret)
|
||||
return &waProto.Message{
|
||||
PollCreationMessage: &waProto.PollCreationMessage{
|
||||
Name: proto.String(question),
|
||||
Options: options,
|
||||
SelectableOptionsCount: proto.Uint32(uint32(maxAnswers)),
|
||||
ContextInfo: ctxInfo,
|
||||
},
|
||||
MessageContextInfo: &waProto.MessageContextInfo{
|
||||
MessageSecret: secret,
|
||||
},
|
||||
}, sender, &extraConvertMeta{PollOptions: optionMap}, err
|
||||
}
|
||||
|
||||
msg := &waProto.Message{}
|
||||
func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.ContextInfo {
|
||||
var ctxInfo waProto.ContextInfo
|
||||
replyToID := content.RelatesTo.GetReplyTo()
|
||||
replyToID := relatesTo.GetReplyTo()
|
||||
if len(replyToID) > 0 {
|
||||
replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID)
|
||||
if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal {
|
||||
if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll) {
|
||||
ctxInfo.StanzaId = &replyToMsg.JID
|
||||
ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String())
|
||||
// Using blank content here seems to work fine on all official WhatsApp apps.
|
||||
|
@ -3613,10 +3721,40 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
if portal.ExpirationTime != 0 {
|
||||
ctxInfo.Expiration = proto.Uint32(portal.ExpirationTime)
|
||||
}
|
||||
return &ctxInfo
|
||||
}
|
||||
|
||||
type extraConvertMeta struct {
|
||||
PollOptions map[[32]byte]string
|
||||
}
|
||||
|
||||
func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) {
|
||||
if evt.Type == TypeMSC3381PollResponse || evt.Type == TypeMSC3381V2PollResponse {
|
||||
return portal.convertMatrixPollVote(ctx, sender, evt)
|
||||
} else if evt.Type == TypeMSC3381PollStart {
|
||||
return portal.convertMatrixPollStart(ctx, sender, evt)
|
||||
}
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if !ok {
|
||||
return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed)
|
||||
}
|
||||
var editRootMsg *database.Message
|
||||
if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" && portal.bridge.Config.Bridge.SendWhatsAppEdits {
|
||||
editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID)
|
||||
if editRootMsg == nil || editRootMsg.Type != database.MsgNormal || editRootMsg.IsFakeJID() || editRootMsg.Sender.User != sender.JID.User {
|
||||
return nil, sender, nil, fmt.Errorf("edit rejected") // TODO more specific error message
|
||||
}
|
||||
if content.NewContent != nil {
|
||||
content = content.NewContent
|
||||
}
|
||||
}
|
||||
|
||||
msg := &waProto.Message{}
|
||||
ctxInfo := portal.generateContextInfo(content.RelatesTo)
|
||||
relaybotFormatted := false
|
||||
if !sender.IsLoggedIn() || (portal.IsPrivateChat() && sender.JID.User != portal.Key.Receiver.User) {
|
||||
if !portal.HasRelaybot() {
|
||||
return nil, sender, errUserNotLoggedIn
|
||||
return nil, sender, nil, errUserNotLoggedIn
|
||||
}
|
||||
relaybotFormatted = portal.addRelaybotFormat(sender, content)
|
||||
sender = portal.GetRelayUser()
|
||||
|
@ -3637,7 +3775,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
case event.MsgText, event.MsgEmote, event.MsgNotice:
|
||||
text := content.Body
|
||||
if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices {
|
||||
return nil, sender, errMNoticeDisabled
|
||||
return nil, sender, nil, errMNoticeDisabled
|
||||
}
|
||||
if content.Format == event.FormatHTML {
|
||||
text, ctxInfo.MentionedJid = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
|
||||
|
@ -3647,11 +3785,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
}
|
||||
msg.ExtendedTextMessage = &waProto.ExtendedTextMessage{
|
||||
Text: &text,
|
||||
ContextInfo: &ctxInfo,
|
||||
ContextInfo: ctxInfo,
|
||||
}
|
||||
hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage)
|
||||
if ctx.Err() != nil {
|
||||
return nil, nil, ctx.Err()
|
||||
return nil, nil, nil, ctx.Err()
|
||||
}
|
||||
if ctxInfo.StanzaId == nil && ctxInfo.MentionedJid == nil && ctxInfo.Expiration == nil && !hasPreview {
|
||||
// No need for extended message
|
||||
|
@ -3661,11 +3799,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
case event.MsgImage:
|
||||
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
|
||||
if media == nil {
|
||||
return nil, sender, err
|
||||
return nil, sender, nil, err
|
||||
}
|
||||
ctxInfo.MentionedJid = media.MentionedJIDs
|
||||
msg.ImageMessage = &waProto.ImageMessage{
|
||||
ContextInfo: &ctxInfo,
|
||||
ContextInfo: ctxInfo,
|
||||
Caption: &media.Caption,
|
||||
JpegThumbnail: media.Thumbnail,
|
||||
Url: &media.URL,
|
||||
|
@ -3678,11 +3816,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
case event.MessageType(event.EventSticker.Type):
|
||||
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaImage)
|
||||
if media == nil {
|
||||
return nil, sender, err
|
||||
return nil, sender, nil, err
|
||||
}
|
||||
ctxInfo.MentionedJid = media.MentionedJIDs
|
||||
msg.StickerMessage = &waProto.StickerMessage{
|
||||
ContextInfo: &ctxInfo,
|
||||
ContextInfo: ctxInfo,
|
||||
PngThumbnail: media.Thumbnail,
|
||||
Url: &media.URL,
|
||||
MediaKey: media.MediaKey,
|
||||
|
@ -3695,12 +3833,12 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
gifPlayback := content.GetInfo().MimeType == "image/gif"
|
||||
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaVideo)
|
||||
if media == nil {
|
||||
return nil, sender, err
|
||||
return nil, sender, nil, err
|
||||
}
|
||||
duration := uint32(content.GetInfo().Duration / 1000)
|
||||
ctxInfo.MentionedJid = media.MentionedJIDs
|
||||
msg.VideoMessage = &waProto.VideoMessage{
|
||||
ContextInfo: &ctxInfo,
|
||||
ContextInfo: ctxInfo,
|
||||
Caption: &media.Caption,
|
||||
JpegThumbnail: media.Thumbnail,
|
||||
Url: &media.URL,
|
||||
|
@ -3715,11 +3853,11 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
case event.MsgAudio:
|
||||
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaAudio)
|
||||
if media == nil {
|
||||
return nil, sender, err
|
||||
return nil, sender, nil, err
|
||||
}
|
||||
duration := uint32(content.GetInfo().Duration / 1000)
|
||||
msg.AudioMessage = &waProto.AudioMessage{
|
||||
ContextInfo: &ctxInfo,
|
||||
ContextInfo: ctxInfo,
|
||||
Url: &media.URL,
|
||||
MediaKey: media.MediaKey,
|
||||
Mimetype: &content.GetInfo().MimeType,
|
||||
|
@ -3738,10 +3876,10 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
case event.MsgFile:
|
||||
media, err := portal.preprocessMatrixMedia(ctx, sender, relaybotFormatted, content, evt.ID, whatsmeow.MediaDocument)
|
||||
if media == nil {
|
||||
return nil, sender, err
|
||||
return nil, sender, nil, err
|
||||
}
|
||||
msg.DocumentMessage = &waProto.DocumentMessage{
|
||||
ContextInfo: &ctxInfo,
|
||||
ContextInfo: ctxInfo,
|
||||
Caption: &media.Caption,
|
||||
JpegThumbnail: media.Thumbnail,
|
||||
Url: &media.URL,
|
||||
|
@ -3764,16 +3902,16 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
case event.MsgLocation:
|
||||
lat, long, err := parseGeoURI(content.GeoURI)
|
||||
if err != nil {
|
||||
return nil, sender, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
|
||||
return nil, sender, nil, fmt.Errorf("%w: %v", errInvalidGeoURI, err)
|
||||
}
|
||||
msg.LocationMessage = &waProto.LocationMessage{
|
||||
DegreesLatitude: &lat,
|
||||
DegreesLongitude: &long,
|
||||
Comment: &content.Body,
|
||||
ContextInfo: &ctxInfo,
|
||||
ContextInfo: ctxInfo,
|
||||
}
|
||||
default:
|
||||
return nil, sender, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
|
||||
return nil, sender, nil, fmt.Errorf("%w %q", errUnknownMsgType, content.MsgType)
|
||||
}
|
||||
|
||||
if editRootMsg != nil {
|
||||
|
@ -3795,7 +3933,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
|
|||
}
|
||||
}
|
||||
|
||||
return msg, sender, nil
|
||||
return msg, sender, nil, nil
|
||||
}
|
||||
|
||||
func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo {
|
||||
|
@ -3815,7 +3953,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
|
|||
start := time.Now()
|
||||
ms := metricSender{portal: portal, timings: &timings}
|
||||
|
||||
allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse
|
||||
allowRelay := evt.Type != TypeMSC3381PollResponse && evt.Type != TypeMSC3381V2PollResponse && evt.Type != TypeMSC3381PollStart
|
||||
if err := portal.canBridgeFrom(sender, allowRelay); err != nil {
|
||||
go ms.sendMessageMetrics(evt, err, "Ignoring", true)
|
||||
return
|
||||
|
@ -3875,14 +4013,16 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
|
|||
|
||||
timings.preproc = time.Since(start)
|
||||
start = time.Now()
|
||||
msg, sender, err := portal.convertMatrixMessage(ctx, sender, evt)
|
||||
msg, sender, extraMeta, err := portal.convertMatrixMessage(ctx, sender, evt)
|
||||
timings.convert = time.Since(start)
|
||||
if msg == nil {
|
||||
go ms.sendMessageMetrics(evt, err, "Error converting", true)
|
||||
return
|
||||
}
|
||||
dbMsgType := database.MsgNormal
|
||||
if msg.EditedMessage == nil {
|
||||
if msg.PollCreationMessage != nil || msg.PollCreationMessageV2 != nil {
|
||||
dbMsgType = database.MsgMatrixPoll
|
||||
} else if msg.EditedMessage == nil {
|
||||
portal.MarkDisappearing(nil, origEvtID, portal.ExpirationTime, true)
|
||||
} else {
|
||||
dbMsgType = database.MsgEdit
|
||||
|
@ -3893,6 +4033,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
|
|||
} else {
|
||||
info.ID = dbMsg.JID
|
||||
}
|
||||
if dbMsgType == database.MsgMatrixPoll && extraMeta != nil && extraMeta.PollOptions != nil {
|
||||
dbMsg.PutPollOptions(extraMeta.PollOptions)
|
||||
}
|
||||
portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID)
|
||||
start = time.Now()
|
||||
resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, info.ID, msg)
|
||||
|
|
Loading…
Reference in a new issue