Generate different deterministic event IDs for captions

This commit is contained in:
Tulir Asokan 2022-09-20 17:31:08 +03:00
parent a112e48467
commit f0401ee81e
2 changed files with 12 additions and 9 deletions

View file

@ -24,10 +24,10 @@ import (
waProto "go.mau.fi/whatsmeow/binary/proto" waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
@ -453,8 +453,11 @@ func (user *User) EnqueueForwardBackfills(portals []*Portal) {
// endregion // endregion
// region Portal backfilling // region Portal backfilling
func (portal *Portal) deterministicEventID(sender types.JID, messageID types.MessageID) id.EventID { func (portal *Portal) deterministicEventID(sender types.JID, messageID types.MessageID, partName string) id.EventID {
data := fmt.Sprintf("%s/whatsapp/%s/%s", portal.MXID, sender.User, messageID) data := fmt.Sprintf("%s/whatsapp/%s/%s", portal.MXID, sender.User, messageID)
if partName != "" {
data += "/" + partName
}
sum := sha256.Sum256([]byte(data)) sum := sha256.Sum256([]byte(data))
return id.EventID(fmt.Sprintf("$%s:whatsapp.com", base64.RawURLEncoding.EncodeToString(sum[:]))) return id.EventID(fmt.Sprintf("$%s:whatsapp.com", base64.RawURLEncoding.EncodeToString(sum[:])))
} }
@ -636,7 +639,7 @@ func (portal *Portal) requestMediaRetries(source *User, eventIDs []id.EventID, i
} }
func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types.MessageInfo, expirationStart uint64, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error { func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types.MessageInfo, expirationStart uint64, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra) mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
if err != nil { if err != nil {
return err return err
} }
@ -644,7 +647,7 @@ func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types
converted.MergeCaption() converted.MergeCaption()
} }
if converted.Caption != nil { if converted.Caption != nil {
captionEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Caption, nil) captionEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
if err != nil { if err != nil {
return err return err
} }
@ -655,8 +658,8 @@ func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types
*infoArray = append(*infoArray, &wrappedInfo{info, database.MsgNormal, converted.Error, converted.MediaKey, expirationStart, converted.ExpiresIn}) *infoArray = append(*infoArray, &wrappedInfo{info, database.MsgNormal, converted.Error, converted.MediaKey, expirationStart, converted.ExpiresIn})
} }
if converted.MultiEvent != nil { if converted.MultiEvent != nil {
for _, subEvtContent := range converted.MultiEvent { for i, subEvtContent := range converted.MultiEvent {
subEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, subEvtContent, nil) subEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
if err != nil { if err != nil {
return err return err
} }
@ -667,7 +670,7 @@ func (portal *Portal) appendBatchEvents(converted *ConvertedMessage, info *types
return nil return nil
} }
func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}) (*event.Event, error) { func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
wrappedContent := event.Content{ wrappedContent := event.Content{
Parsed: content, Parsed: content,
Raw: extraContent, Raw: extraContent,
@ -681,7 +684,7 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
} }
var eventID id.EventID var eventID id.EventID
if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
eventID = portal.deterministicEventID(info.Sender, info.ID) eventID = portal.deterministicEventID(info.Sender, info.ID, partName)
} }
return &event.Event{ return &event.Event{

View file

@ -1580,7 +1580,7 @@ func (portal *Portal) SetReply(content *event.MessageEventContent, replyTo *Repl
message := portal.bridge.DB.Message.GetByJID(portal.Key, replyTo.MessageID) message := portal.bridge.DB.Message.GetByJID(portal.Key, replyTo.MessageID)
if message == nil || message.IsFakeMXID() { if message == nil || message.IsFakeMXID() {
if isBackfill && portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { if isBackfill && portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(portal.deterministicEventID(replyTo.Sender, replyTo.MessageID)) content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(portal.deterministicEventID(replyTo.Sender, replyTo.MessageID, ""))
return true return true
} }
return false return false