From b8dc3c0e569a372ea30cea67905d45a7cfb7e32a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Nov 2022 00:20:14 +0200 Subject: [PATCH] Bridge incoming votes in MSC3381 polls --- main.go | 2 ++ messagetracking.go | 2 ++ portal.go | 86 +++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/main.go b/main.go index 8d3e025..4880dcf 100644 --- a/main.go +++ b/main.go @@ -87,6 +87,8 @@ func (br *WABridge) Init() { // TODO this is a weird place for this br.EventProcessor.On(event.EphemeralEventPresence, br.HandlePresence) + br.EventProcessor.On(TypeMSC3881PollResponse, br.MatrixHandler.HandleMessage) + br.EventProcessor.On(TypeMSC3881V2PollResponse, br.MatrixHandler.HandleMessage) Segment.log = br.Log.Sub("Segment") Segment.key = br.Config.SegmentKey diff --git a/messagetracking.go b/messagetracking.go index 477b2cf..e83628f 100644 --- a/messagetracking.go +++ b/messagetracking.go @@ -185,6 +185,8 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin msgType = "reaction" case event.EventRedaction: msgType = "redaction" + case TypeMSC3881PollResponse, TypeMSC3881V2PollResponse: + msgType = "poll response" default: msgType = "unknown event" } diff --git a/portal.go b/portal.go index 43b5fa8..702da2d 100644 --- a/portal.go +++ b/portal.go @@ -33,6 +33,7 @@ import ( "math" "mime" "net/http" + "reflect" "runtime/debug" "strconv" "strings" @@ -315,7 +316,7 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg PortalMatrixMessage) { portal.handleMatrixReadReceipt(msg.user, "", evtTS, false) timings.implicitRR = time.Since(implicitRRStart) switch msg.evt.Type { - case event.EventMessage, event.EventSticker: + case event.EventMessage, event.EventSticker, TypeMSC3881V2PollResponse, TypeMSC3881PollResponse: portal.HandleMatrixMessage(msg.user, msg.evt, timings) case event.EventRedaction: portal.HandleMatrixRedaction(msg.user, msg.evt) @@ -2121,9 +2122,9 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou selectedHashes[i] = hex.EncodeToString(opt) } - evtType := event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"} + evtType := TypeMSC3881PollResponse if portal.bridge.Config.Bridge.ExtEvPolls == 2 { - evtType.Type = "org.matrix.msc3381.v2.poll.response" + evtType = TypeMSC3881V2PollResponse } return &ConvertedMessage{ Intent: intent, @@ -3333,7 +3334,81 @@ func getUnstableWaveform(content map[string]interface{}) []byte { return output } +var ( + TypeMSC3881PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3381.poll.response"} + TypeMSC3881V2PollResponse = event.Type{Class: event.MessageEventType, Type: "org.matrix.msc3881.v2.poll.response"} +) + +type PollResponseContent struct { + RelatesTo event.RelatesTo `json:"m.relates_to"` + V1Response struct { + Answers []string `json:"answers"` + } `json:"org.matrix.msc3381.poll.response"` + V2Selections []string `json:"org.matrix.msc3381.v2.selections"` +} + +func (content *PollResponseContent) GetRelatesTo() *event.RelatesTo { + return &content.RelatesTo +} + +func (content *PollResponseContent) OptionalGetRelatesTo() *event.RelatesTo { + if content.RelatesTo.Type == "" { + return nil + } + return &content.RelatesTo +} + +func (content *PollResponseContent) SetRelatesTo(rel *event.RelatesTo) { + content.RelatesTo = *rel +} + +func init() { + event.TypeMap[TypeMSC3881PollResponse] = reflect.TypeOf(PollResponseContent{}) + event.TypeMap[TypeMSC3881V2PollResponse] = reflect.TypeOf(PollResponseContent{}) +} + +func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) { + content, ok := evt.Content.Parsed.(*PollResponseContent) + if !ok { + return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) + } + var answers []string + if content.V1Response.Answers != nil { + answers = content.V1Response.Answers + } else if content.V2Selections != nil { + answers = content.V2Selections + } + pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID) + if pollMsg == nil { + return nil, sender, errTargetNotFound + } + pollMsgInfo := &types.MessageInfo{ + MessageSource: types.MessageSource{ + Chat: portal.Key.JID, + Sender: pollMsg.Sender, + IsFromMe: pollMsg.Sender.User == sender.JID.User, + IsGroup: portal.IsGroupChat(), + }, + ID: pollMsg.JID, + Type: "poll", + } + optionHashes := make([][]byte, 0, len(answers)) + for _, selection := range answers { + hash, _ := hex.DecodeString(selection) + if hash != nil && len(hash) == 32 { + optionHashes = append(optionHashes, hash) + } + } + pollUpdate, err := sender.Client.EncryptPollVote(pollMsgInfo, &waProto.PollVoteMessage{ + SelectedOptions: optionHashes, + }) + return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, err +} + func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, error) { + if evt.Type == TypeMSC3881PollResponse || evt.Type == TypeMSC3881V2PollResponse { + return portal.convertMatrixPollVote(ctx, sender, evt) + } content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok { return nil, sender, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) @@ -3351,7 +3426,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev msg := &waProto.Message{} var ctxInfo waProto.ContextInfo - replyToID := content.GetReplyTo() + replyToID := content.RelatesTo.GetReplyTo() if len(replyToID) > 0 { replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID) if replyToMsg != nil && !replyToMsg.IsFakeJID() && replyToMsg.Type == database.MsgNormal { @@ -3570,7 +3645,8 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing start := time.Now() ms := metricSender{portal: portal, timings: &timings} - if err := portal.canBridgeFrom(sender, true); err != nil { + allowRelay := evt.Type != TypeMSC3881PollResponse && evt.Type != TypeMSC3881V2PollResponse + if err := portal.canBridgeFrom(sender, allowRelay); err != nil { go ms.sendMessageMetrics(evt, err, "Ignoring", true) return } else if portal.Key.JID == types.StatusBroadcastJID && portal.bridge.Config.Bridge.DisableStatusBroadcastSend {