From 8b1308595f7f2f3b775a2b65249de622dc70a443 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 4 Sep 2023 19:28:15 +0300 Subject: [PATCH] Add support for collecting incoming galleries into single event --- config/bridge.go | 1 + config/upgrade.go | 1 + database/message.go | 30 ++++++++++--- example-config.yaml | 2 + historysync.go | 2 +- portal.go | 102 +++++++++++++++++++++++++++++++++++++++----- 6 files changed, 119 insertions(+), 19 deletions(-) diff --git a/config/bridge.go b/config/bridge.go index 4e5a8b0..689c263 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -111,6 +111,7 @@ type BridgeConfig struct { FederateRooms bool `yaml:"federate_rooms"` URLPreviews bool `yaml:"url_previews"` CaptionInMessage bool `yaml:"caption_in_message"` + BeeperGalleries bool `yaml:"beeper_galleries"` ExtEvPolls bool `yaml:"extev_polls"` CrossRoomReplies bool `yaml:"cross_room_replies"` DisableReplyFallbacks bool `yaml:"disable_reply_fallbacks"` diff --git a/config/upgrade.go b/config/upgrade.go index b405130..f307651 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -103,6 +103,7 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "crash_on_stream_replaced") helper.Copy(up.Bool, "bridge", "url_previews") helper.Copy(up.Bool, "bridge", "caption_in_message") + helper.Copy(up.Bool, "bridge", "beeper_galleries") if intPolls, ok := helper.Get(up.Int, "bridge", "extev_polls"); ok { val := "false" if intPolls != "0" { diff --git a/database/message.go b/database/message.go index 8f35de8..dbd7e9c 100644 --- a/database/message.go +++ b/database/message.go @@ -19,6 +19,7 @@ package database import ( "database/sql" "errors" + "fmt" "strings" "time" @@ -133,12 +134,13 @@ const ( type MessageType string const ( - MsgUnknown MessageType = "" - MsgFake MessageType = "fake" - MsgNormal MessageType = "message" - MsgReaction MessageType = "reaction" - MsgEdit MessageType = "edit" - MsgMatrixPoll MessageType = "matrix-poll" + MsgUnknown MessageType = "" + MsgFake MessageType = "fake" + MsgNormal MessageType = "message" + MsgReaction MessageType = "reaction" + MsgEdit MessageType = "edit" + MsgMatrixPoll MessageType = "matrix-poll" + MsgBeeperGallery MessageType = "beeper-gallery" ) type Message struct { @@ -155,6 +157,8 @@ type Message struct { Type MessageType Error MessageErrorType + GalleryPart int + BroadcastListJID types.JID } @@ -166,6 +170,8 @@ func (msg *Message) IsFakeJID() bool { return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID) } +const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s" + func (msg *Message) Scan(row dbutil.Scannable) *Message { var ts int64 err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID) @@ -175,6 +181,12 @@ func (msg *Message) Scan(row dbutil.Scannable) *Message { } return nil } + if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") { + _, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID) + if err != nil { + msg.log.Errorln("Parsing gallery MXID failed:", err) + } + } if ts != 0 { msg.Timestamp = time.Unix(ts, 0) } @@ -190,11 +202,15 @@ func (msg *Message) Insert(txn dbutil.Execable) { if msg.Sender.IsEmpty() { sender = "" } + mxid := msg.MXID.String() + if msg.GalleryPart != 0 { + mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid) + } _, err := txn.Exec(` INSERT INTO message (chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - `, msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID) + `, msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID) if err != nil { msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err) } diff --git a/example-config.yaml b/example-config.yaml index be56be3..33e5a79 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -300,6 +300,8 @@ bridge: # Send captions in the same message as images. This will send data compatible with both MSC2530 and MSC3552. # This is currently not supported in most clients. caption_in_message: false + # Send galleries as a single event? This is not an MSC (yet). + beeper_galleries: false # Should polls be sent using MSC3381 event types? extev_polls: false # Should cross-chat replies from WhatsApp be bridged? Most servers and clients don't support this. diff --git a/historysync.go b/historysync.go index 175746e..9e7ef75 100644 --- a/historysync.go +++ b/historysync.go @@ -850,7 +850,7 @@ func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID, } eventID := eventIDs[i] - portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, info.Error) + portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error) if info.Type == database.MsgReaction { portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID) } diff --git a/portal.go b/portal.go index 7b81ccc..4d9004e 100644 --- a/portal.go +++ b/portal.go @@ -284,12 +284,50 @@ type Portal struct { mediaErrorCache map[types.MessageID]*FailedMediaMeta + galleryCache []*event.MessageEventContent + galleryCacheRootEvent id.EventID + galleryCacheStart time.Time + galleryCacheReplyTo *ReplyInfo + galleryCacheSender types.JID + currentlySleepingToDelete sync.Map relayUser *User parentPortal *Portal } +const GalleryMaxTime = 10 * time.Minute + +func (portal *Portal) stopGallery() { + if portal.galleryCache != nil { + portal.galleryCache = nil + portal.galleryCacheSender = types.EmptyJID + portal.galleryCacheReplyTo = nil + portal.galleryCacheStart = time.Time{} + portal.galleryCacheRootEvent = "" + } +} + +func (portal *Portal) startGallery(evt *events.Message, msg *ConvertedMessage) { + portal.galleryCache = []*event.MessageEventContent{msg.Content} + portal.galleryCacheSender = evt.Info.Sender.ToNonAD() + portal.galleryCacheReplyTo = msg.ReplyTo + portal.galleryCacheStart = time.Now() +} + +func (portal *Portal) extendGallery(msg *ConvertedMessage) int { + portal.galleryCache = append(portal.galleryCache, msg.Content) + msg.Content = &event.MessageEventContent{ + MsgType: event.MsgBeeperGallery, + Body: "Sent a gallery", + BeeperGalleryImages: portal.galleryCache, + } + msg.Content.SetEdit(portal.galleryCacheRootEvent) + // Don't set the gallery images in the edit fallback + msg.Content.BeeperGalleryImages = nil + return len(portal.galleryCache) - 1 +} + var ( _ bridge.Portal = (*Portal)(nil) _ bridge.ReadReceiptHandlingPortal = (*Portal)(nil) @@ -319,8 +357,10 @@ func (portal *Portal) handleWhatsAppMessageLoopItem(msg PortalMessage) { case msg.receipt != nil: portal.handleReceipt(msg.receipt, msg.source) case msg.undecryptable != nil: + portal.stopGallery() portal.handleUndecryptableMessage(msg.source, msg.undecryptable) case msg.fake != nil: + portal.stopGallery() msg.fake.ID = "FAKE::" + msg.fake.ID portal.handleFakeMessage(*msg.fake) default: @@ -746,7 +786,7 @@ func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.Undec portal.log.Errorfln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err) return } - portal.finishHandling(nil, &evt.Info, resp.EventID, intent.UserID, database.MsgUnknown, database.MsgErrDecryptionFailed) + portal.finishHandling(nil, &evt.Info, resp.EventID, intent.UserID, database.MsgUnknown, 0, database.MsgErrDecryptionFailed) } func (portal *Portal) handleFakeMessage(msg fakeMessage) { @@ -784,7 +824,7 @@ func (portal *Portal) handleFakeMessage(msg fakeMessage) { MessageSource: types.MessageSource{ Sender: msg.Sender, }, - }, resp.EventID, intent.UserID, database.MsgFake, database.MsgNoError) + }, resp.EventID, intent.UserID, database.MsgFake, 0, database.MsgNoError) } } @@ -841,6 +881,17 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica } converted := portal.convertMessage(intent, source, &evt.Info, evt.Message, false) if converted != nil { + isGalleriable := portal.bridge.Config.Bridge.BeeperGalleries && + (evt.Message.ImageMessage != nil || evt.Message.VideoMessage != nil) && + (portal.galleryCache == nil || + (evt.Info.Sender.ToNonAD() == portal.galleryCacheSender && + converted.ReplyTo.Equals(portal.galleryCacheReplyTo) && + time.Since(portal.galleryCacheStart) < GalleryMaxTime)) && + // Captions aren't allowed in galleries (this needs to be checked before the caption is merged) + converted.Caption == nil && + // Images can't be edited + editTargetMsg == nil + if !historical && portal.IsPrivateChat() && evt.Info.Sender.Device == 0 && converted.ExpiresIn > 0 && portal.ExpirationTime == 0 { portal.zlog.Info(). Str("timer", converted.ExpiresIn.String()). @@ -871,6 +922,20 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica dbMsgType = database.MsgEdit converted.Content.SetEdit(editTargetMsg.MXID) } + galleryStarted := false + var galleryPart int + if isGalleriable { + if portal.galleryCache == nil { + portal.startGallery(evt, converted) + galleryStarted = true + } else { + galleryPart = portal.extendGallery(converted) + dbMsgType = database.MsgBeeperGallery + } + } else if editTargetMsg == nil { + // Stop collecting a gallery (except if it's an edit) + portal.stopGallery() + } resp, err := portal.sendMessage(converted.Intent, converted.Type, converted.Content, converted.Extra, evt.Info.Timestamp.UnixMilli()) if err != nil { portal.log.Errorfln("Failed to send %s to Matrix: %v", msgID, err) @@ -880,6 +945,11 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica } eventID = resp.EventID lastEventID = eventID + if galleryStarted { + portal.galleryCacheRootEvent = eventID + } else if galleryPart != 0 { + eventID = portal.galleryCacheRootEvent + } } // TODO figure out how to handle captions with undecryptable messages turning decryptable if converted.Caption != nil && existingMsg == nil && editTargetMsg == nil { @@ -912,7 +982,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica } } if len(eventID) != 0 { - portal.finishHandling(existingMsg, &evt.Info, eventID, intent.UserID, dbMsgType, converted.Error) + portal.finishHandling(existingMsg, &evt.Info, eventID, intent.UserID, dbMsgType, galleryPart, converted.Error) } } else if msgType == "reaction" || msgType == "encrypted reaction" { if evt.Message.GetEncReactionMessage() != nil { @@ -957,12 +1027,13 @@ func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.Messa return false } -func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, info *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message { +func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, info *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, isSent, recent bool, msgType database.MessageType, galleryPart int, errType database.MessageErrorType) *database.Message { if msg == nil { msg = portal.bridge.DB.Message.New() msg.Chat = portal.Key msg.JID = info.ID msg.MXID = mxid + msg.GalleryPart = galleryPart msg.Timestamp = info.Timestamp msg.Sender = info.Sender msg.SenderMXID = senderMXID @@ -1017,8 +1088,8 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo, msgT return intent } -func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, msgType database.MessageType, errType database.MessageErrorType) { - portal.markHandled(nil, existing, message, mxid, senderMXID, true, true, msgType, errType) +func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, msgType database.MessageType, galleryPart int, errType database.MessageErrorType) { + portal.markHandled(nil, existing, message, mxid, senderMXID, true, true, msgType, galleryPart, errType) portal.sendDeliveryReceipt(mxid) var suffix string if errType == database.MsgErrDecryptionFailed { @@ -2100,7 +2171,7 @@ func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user * if err != nil { portal.log.Errorfln("Failed to redact reaction %s/%s from %s to %s: %v", existing.MXID, existing.JID, info.Sender, targetJID, err) } - portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, database.MsgNoError) + portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, 0, database.MsgNoError) existing.Delete() } else { target := portal.bridge.DB.Message.GetByJID(portal.Key, targetJID) @@ -2121,7 +2192,7 @@ func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user * return } - portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, database.MsgNoError) + portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, 0, database.MsgNoError) portal.upsertReaction(nil, intent, target.JID, info.Sender, resp.EventID, info.ID) } } @@ -2209,6 +2280,15 @@ type ReplyInfo struct { Sender types.JID } +func (r *ReplyInfo) Equals(other *ReplyInfo) bool { + if r == nil { + return other == nil + } else if other == nil { + return false + } + return r.MessageID == other.MessageID && r.Chat == other.Chat && r.Sender == other.Sender +} + func (r ReplyInfo) MarshalZerologObject(e *zerolog.Event) { e.Str("message_id", r.MessageID) e.Str("chat_jid", r.Chat.String()) @@ -3927,7 +4007,7 @@ func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.C replyToID := relatesTo.GetReplyTo() if len(replyToID) > 0 { replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID) - if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll) { + if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll || replyToMsg.Type == database.MsgBeeperGallery) { ctxInfo.StanzaId = &replyToMsg.JID ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String()) // Using blank content here seems to work fine on all official WhatsApp apps. @@ -4283,7 +4363,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing } info := portal.generateMessageInfo(sender) if dbMsg == nil { - dbMsg = portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, dbMsgType, database.MsgNoError) + dbMsg = portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, dbMsgType, 0, database.MsgNoError) } else { info.ID = dbMsg.JID } @@ -4338,7 +4418,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID) } info := portal.generateMessageInfo(sender) - dbMsg := portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, database.MsgReaction, database.MsgNoError) + dbMsg := portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, database.MsgReaction, 0, database.MsgNoError) portal.upsertReaction(nil, nil, target.JID, sender.JID, evt.ID, info.ID) portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID) resp, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp)