Store message sender mxid in database

This commit is contained in:
Tulir Asokan 2023-05-26 15:55:53 +03:00
parent 10aa66a128
commit 559ac719a4
5 changed files with 54 additions and 43 deletions

View file

@ -44,27 +44,27 @@ func (mq *MessageQuery) New() *Message {
const ( const (
getAllMessagesQuery = ` getAllMessagesQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 WHERE chat_jid=$1 AND chat_receiver=$2
` `
getMessageByJIDQuery = ` getMessageByJIDQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3 WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3
` `
getMessageByMXIDQuery = ` getMessageByMXIDQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE mxid=$1 WHERE mxid=$1
` `
getLastMessageInChatQuery = ` getLastMessageInChatQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1 WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp<=$3 AND sent=true ORDER BY timestamp DESC LIMIT 1
` `
getFirstMessageInChatQuery = ` getFirstMessageInChatQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1 WHERE chat_jid=$1 AND chat_receiver=$2 AND sent=true ORDER BY timestamp ASC LIMIT 1
` `
getMessagesBetweenQuery = ` getMessagesBetweenQuery = `
SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid FROM message SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
` `
) )
@ -146,14 +146,15 @@ type Message struct {
db *Database db *Database
log log.Logger log log.Logger
Chat PortalKey Chat PortalKey
JID types.MessageID JID types.MessageID
MXID id.EventID MXID id.EventID
Sender types.JID Sender types.JID
Timestamp time.Time SenderMXID id.UserID
Sent bool Timestamp time.Time
Type MessageType Sent bool
Error MessageErrorType Type MessageType
Error MessageErrorType
BroadcastListJID types.JID BroadcastListJID types.JID
} }
@ -168,7 +169,7 @@ func (msg *Message) IsFakeJID() bool {
func (msg *Message) Scan(row dbutil.Scannable) *Message { func (msg *Message) Scan(row dbutil.Scannable) *Message {
var ts int64 var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID) err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
msg.log.Errorln("Database scan failed:", err) msg.log.Errorln("Database scan failed:", err)
@ -192,9 +193,9 @@ func (msg *Message) Insert(txn dbutil.Execable) {
} }
_, err := txn.Exec(` _, err := txn.Exec(`
INSERT INTO message INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent, type, error, broadcast_list_jid) (chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`, msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID) `, msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
if err != nil { if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err) msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
} }

View file

@ -1,4 +1,4 @@
-- v0 -> v55: Latest revision -- v0 -> v56 (compatible with v45+): Latest revision
CREATE TABLE "user" ( CREATE TABLE "user" (
mxid TEXT PRIMARY KEY, mxid TEXT PRIMARY KEY,
@ -70,6 +70,7 @@ CREATE TABLE message (
jid TEXT, jid TEXT,
mxid TEXT UNIQUE, mxid TEXT UNIQUE,
sender TEXT, sender TEXT,
sender_mxid TEXT NOT NULL DEFAULT '',
timestamp BIGINT, timestamp BIGINT,
sent BOOLEAN, sent BOOLEAN,
error error_type, error error_type,

View file

@ -0,0 +1,2 @@
-- v56 (compatible with v45+): Store whether custom contact info has been set for a puppet
ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT '';

View file

@ -45,6 +45,8 @@ type wrappedInfo struct {
Type database.MessageType Type database.MessageType
Error database.MessageErrorType Error database.MessageErrorType
SenderMXID id.UserID
ReactionTarget types.MessageID ReactionTarget types.MessageID
MediaKey []byte MediaKey []byte
@ -268,6 +270,7 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
msg.MXID = resp.EventID msg.MXID = resp.EventID
msg.JID = types.MessageID(resp.EventID) msg.JID = types.MessageID(resp.EventID)
msg.Timestamp = conv.LastMessageTimestamp msg.Timestamp = conv.LastMessageTimestamp
msg.SenderMXID = portal.MainIntent().UserID
msg.Sent = true msg.Sent = true
msg.Type = database.MsgFake msg.Type = database.MsgFake
msg.Insert(nil) msg.Insert(nil)
@ -749,6 +752,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
mainInfo := &wrappedInfo{ mainInfo := &wrappedInfo{
MessageInfo: info, MessageInfo: info,
Type: database.MsgNormal, Type: database.MsgNormal,
SenderMXID: mainEvt.Sender,
Error: converted.Error, Error: converted.Error,
MediaKey: converted.MediaKey, MediaKey: converted.MediaKey,
ExpirationStart: expirationStart, ExpirationStart: expirationStart,
@ -783,6 +787,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
*eventsArray = append(*eventsArray, reactionEvent) *eventsArray = append(*eventsArray, reactionEvent)
*infoArray = append(*infoArray, &wrappedInfo{ *infoArray = append(*infoArray, &wrappedInfo{
MessageInfo: reactionInfo, MessageInfo: reactionInfo,
SenderMXID: reactionEvent.Sender,
ReactionTarget: info.ID, ReactionTarget: info.ID,
Type: database.MsgReaction, Type: database.MsgReaction,
}) })
@ -872,7 +877,7 @@ func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID,
} }
eventID := eventIDs[i] eventID := eventIDs[i]
portal.markHandled(txn, nil, info.MessageInfo, eventID, true, false, info.Type, info.Error) portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, info.Error)
if info.Type == database.MsgReaction { if info.Type == database.MsgReaction {
portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID) portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
} }
@ -896,6 +901,7 @@ func (portal *Portal) sendPostBackfillDummy(lastTimestamp time.Time, insertionEv
msg := portal.bridge.DB.Message.New() msg := portal.bridge.DB.Message.New()
msg.Chat = portal.Key msg.Chat = portal.Key
msg.MXID = resp.EventID msg.MXID = resp.EventID
msg.SenderMXID = portal.MainIntent().UserID
msg.JID = types.MessageID(resp.EventID) msg.JID = types.MessageID(resp.EventID)
msg.Timestamp = lastTimestamp.Add(1 * time.Second) msg.Timestamp = lastTimestamp.Add(1 * time.Second)
msg.Sent = true msg.Sent = true

View file

@ -670,7 +670,7 @@ func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.Undec
portal.log.Errorfln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err) portal.log.Errorfln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err)
return return
} }
portal.finishHandling(nil, &evt.Info, resp.EventID, database.MsgUnknown, database.MsgErrDecryptionFailed) portal.finishHandling(nil, &evt.Info, resp.EventID, intent.UserID, database.MsgUnknown, database.MsgErrDecryptionFailed)
} }
func (portal *Portal) handleFakeMessage(msg fakeMessage) { func (portal *Portal) handleFakeMessage(msg fakeMessage) {
@ -703,7 +703,7 @@ func (portal *Portal) handleFakeMessage(msg fakeMessage) {
MessageSource: types.MessageSource{ MessageSource: types.MessageSource{
Sender: msg.Sender, Sender: msg.Sender,
}, },
}, resp.EventID, database.MsgFake, database.MsgNoError) }, resp.EventID, intent.UserID, database.MsgFake, database.MsgNoError)
} }
} }
@ -818,7 +818,7 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message) {
} }
} }
if len(eventID) != 0 { if len(eventID) != 0 {
portal.finishHandling(existingMsg, &evt.Info, eventID, dbMsgType, converted.Error) portal.finishHandling(existingMsg, &evt.Info, eventID, intent.UserID, dbMsgType, converted.Error)
} }
} else if msgType == "reaction" || msgType == "encrypted reaction" { } else if msgType == "reaction" || msgType == "encrypted reaction" {
if evt.Message.GetEncReactionMessage() != nil { if evt.Message.GetEncReactionMessage() != nil {
@ -863,7 +863,7 @@ func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.Messa
return false return false
} }
func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, info *types.MessageInfo, mxid id.EventID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message { func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, info *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, isSent, recent bool, msgType database.MessageType, errType database.MessageErrorType) *database.Message {
if msg == nil { if msg == nil {
msg = portal.bridge.DB.Message.New() msg = portal.bridge.DB.Message.New()
msg.Chat = portal.Key msg.Chat = portal.Key
@ -871,6 +871,7 @@ func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message,
msg.MXID = mxid msg.MXID = mxid
msg.Timestamp = info.Timestamp msg.Timestamp = info.Timestamp
msg.Sender = info.Sender msg.Sender = info.Sender
msg.SenderMXID = senderMXID
msg.Sent = isSent msg.Sent = isSent
msg.Type = msgType msg.Type = msgType
msg.Error = errType msg.Error = errType
@ -922,8 +923,8 @@ func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo, msgT
return intent return intent
} }
func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, msgType database.MessageType, errType database.MessageErrorType) { func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, msgType database.MessageType, errType database.MessageErrorType) {
portal.markHandled(nil, existing, message, mxid, true, true, msgType, errType) portal.markHandled(nil, existing, message, mxid, senderMXID, true, true, msgType, errType)
portal.sendDeliveryReceipt(mxid) portal.sendDeliveryReceipt(mxid)
var suffix string var suffix string
if errType == database.MsgErrDecryptionFailed { if errType == database.MsgErrDecryptionFailed {
@ -1881,19 +1882,20 @@ func (portal *Portal) MainIntent() *appservice.IntentAPI {
return portal.bridge.Bot return portal.bridge.Bot
} }
func (portal *Portal) addReplyMention(content *event.MessageEventContent, sender types.JID) { func (portal *Portal) addReplyMention(content *event.MessageEventContent, sender types.JID, senderMXID id.UserID) {
if content.Mentions == nil { if content.Mentions == nil || (sender.IsEmpty() && senderMXID == "") {
return return
} }
var mxid id.UserID if senderMXID == "" {
if user := portal.bridge.GetUserByJID(sender); user != nil { if user := portal.bridge.GetUserByJID(sender); user != nil {
mxid = user.MXID senderMXID = user.MXID
} else { } else {
puppet := portal.bridge.GetPuppetByJID(sender) puppet := portal.bridge.GetPuppetByJID(sender)
mxid = puppet.MXID senderMXID = puppet.MXID
}
} }
if slices.Contains(content.Mentions.UserIDs, mxid) { if senderMXID != "" && !slices.Contains(content.Mentions.UserIDs, senderMXID) {
content.Mentions.UserIDs = append(content.Mentions.UserIDs, mxid) content.Mentions.UserIDs = append(content.Mentions.UserIDs, senderMXID)
} }
} }
@ -1925,13 +1927,12 @@ func (portal *Portal) SetReply(content *event.MessageEventContent, replyTo *Repl
if message == nil || message.IsFakeMXID() { if message == nil || message.IsFakeMXID() {
if isBackfill && portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { if isBackfill && portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(targetPortal.deterministicEventID(replyTo.Sender, replyTo.MessageID, "")) content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(targetPortal.deterministicEventID(replyTo.Sender, replyTo.MessageID, ""))
portal.addReplyMention(content, replyTo.Sender) portal.addReplyMention(content, replyTo.Sender, "")
return true return true
} }
return false return false
} }
// TODO store sender mxid in db message portal.addReplyMention(content, message.Sender, message.SenderMXID)
portal.addReplyMention(content, message.Sender)
content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(message.MXID) content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(message.MXID)
if portal.bridge.Config.Bridge.DisableReplyFallbacks { if portal.bridge.Config.Bridge.DisableReplyFallbacks {
return true return true
@ -1973,7 +1974,7 @@ func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user *
if err != nil { if err != nil {
portal.log.Errorfln("Failed to redact reaction %s/%s from %s to %s: %v", existing.MXID, existing.JID, info.Sender, targetJID, err) portal.log.Errorfln("Failed to redact reaction %s/%s from %s to %s: %v", existing.MXID, existing.JID, info.Sender, targetJID, err)
} }
portal.finishHandling(existingMsg, info, resp.EventID, database.MsgReaction, database.MsgNoError) portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, database.MsgNoError)
existing.Delete() existing.Delete()
} else { } else {
target := portal.bridge.DB.Message.GetByJID(portal.Key, targetJID) target := portal.bridge.DB.Message.GetByJID(portal.Key, targetJID)
@ -1994,7 +1995,7 @@ func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user *
return return
} }
portal.finishHandling(existingMsg, info, resp.EventID, database.MsgReaction, database.MsgNoError) portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, database.MsgNoError)
portal.upsertReaction(nil, intent, target.JID, info.Sender, resp.EventID, info.ID) portal.upsertReaction(nil, intent, target.JID, info.Sender, resp.EventID, info.ID)
} }
} }
@ -4134,7 +4135,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing
} }
info := portal.generateMessageInfo(sender) info := portal.generateMessageInfo(sender)
if dbMsg == nil { if dbMsg == nil {
dbMsg = portal.markHandled(nil, nil, info, evt.ID, false, true, dbMsgType, database.MsgNoError) dbMsg = portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, dbMsgType, database.MsgNoError)
} else { } else {
info.ID = dbMsg.JID info.ID = dbMsg.JID
} }
@ -4189,7 +4190,7 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error
return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID) return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID)
} }
info := portal.generateMessageInfo(sender) info := portal.generateMessageInfo(sender)
dbMsg := portal.markHandled(nil, nil, info, evt.ID, false, true, database.MsgReaction, database.MsgNoError) dbMsg := portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, database.MsgReaction, database.MsgNoError)
portal.upsertReaction(nil, nil, target.JID, sender.JID, evt.ID, info.ID) portal.upsertReaction(nil, nil, target.JID, sender.JID, evt.ID, info.ID)
portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID) portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID)
resp, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp) resp, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp)