Update dependencies and lots of code

* Bump minimum Go version to 1.21
* Add contexts everywhere
* Switch database code to new dbutil patterns
* Finish switching away from maulogger
This commit is contained in:
Tulir Asokan 2024-03-11 22:27:10 +02:00
parent f8a22aab06
commit 103bfc31c6
37 changed files with 3423 additions and 2918 deletions

View file

@ -8,7 +8,8 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
go-version: ["1.20", "1.21"] go-version: ["1.21", "1.22"]
name: Lint ${{ matrix.go-version == '1.22' && '(latest)' || '(old)' }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4

View file

@ -15,6 +15,6 @@ repos:
- id: go-vet-repo-mod - id: go-vet-repo-mod
- repo: https://github.com/beeper/pre-commit-go - repo: https://github.com/beeper/pre-commit-go
rev: v0.2.2 rev: v0.3.1
hooks: hooks:
- id: zerolog-ban-msgf - id: zerolog-ban-msgf

View file

@ -1,3 +1,9 @@
# v0.10.6 (unreleased)
* Bumped minimum Go version to 1.21.
* Added 8-letter code pairing support to provisioning API.
* Added more bugs to fix later.
# v0.10.5 (2023-12-16) # v0.10.5 (2023-12-16)
* Added support for sending media to channels. * Added support for sending media to channels.

View file

@ -22,7 +22,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
log "maunium.net/go/maulogger/v2" "github.com/rs/zerolog"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@ -30,7 +30,7 @@ type AnalyticsClient struct {
url string url string
key string key string
userID string userID string
log log.Logger log zerolog.Logger
client http.Client client http.Client
} }
@ -88,9 +88,9 @@ func (sc *AnalyticsClient) Track(userID id.UserID, event string, properties ...m
props["bridge"] = "whatsapp" props["bridge"] = "whatsapp"
err := sc.trackSync(userID, event, props) err := sc.trackSync(userID, event, props)
if err != nil { if err != nil {
sc.log.Errorfln("Error tracking %s: %v", event, err) sc.log.Err(err).Str("event", event).Msg("Error tracking event")
} else { } else {
sc.log.Debugln("Tracked", event) sc.log.Debug().Str("event", event).Msg("Tracked event")
} }
}() }()
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan, Sumner Evans // Copyright (C) 2024 Tulir Asokan, Sumner Evans
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,22 +17,21 @@
package main package main
import ( import (
"context"
"time" "time"
log "maunium.net/go/maulogger/v2" "github.com/rs/zerolog"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
) )
type BackfillQueue struct { type BackfillQueue struct {
BackfillQuery *database.BackfillQuery BackfillQuery *database.BackfillTaskQuery
reCheckChannels []chan bool reCheckChannels []chan bool
log log.Logger
} }
func (bq *BackfillQueue) ReCheck() { func (bq *BackfillQueue) ReCheck() {
bq.log.Infofln("Sending re-checks to %d channels", len(bq.reCheckChannels))
for _, channel := range bq.reCheckChannels { for _, channel := range bq.reCheckChannels {
go func(c chan bool) { go func(c chan bool) {
c <- true c <- true
@ -40,12 +39,19 @@ func (bq *BackfillQueue) ReCheck() {
} }
} }
func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.Backfill { func (bq *BackfillQueue) GetNextBackfill(ctx context.Context, userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.BackfillTask {
for { for {
if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(userID, waitForBackfillTypes) { if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(ctx, userID, waitForBackfillTypes) {
// check for immediate when dealing with deferred // check for immediate when dealing with deferred
if backfill := bq.BackfillQuery.GetNext(userID, backfillTypes); backfill != nil { if backfill, err := bq.BackfillQuery.GetNext(ctx, userID, backfillTypes); err != nil {
backfill.MarkDispatched() zerolog.Ctx(ctx).Err(err).Msg("Failed to get next backfill task")
} else if backfill != nil {
err = backfill.MarkDispatched(ctx)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Int("queue_id", backfill.QueueID).
Msg("Failed to mark backfill task as dispatched")
}
return backfill return backfill
} }
} }
@ -58,38 +64,73 @@ func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []datab
} }
func (user *User) HandleBackfillRequestsLoop(backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType) { func (user *User) HandleBackfillRequestsLoop(backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType) {
log := user.zlog.With().
Str("action", "backfill request loop").
Any("types", backfillTypes).
Logger()
ctx := log.WithContext(context.TODO())
reCheckChannel := make(chan bool) reCheckChannel := make(chan bool)
user.BackfillQueue.reCheckChannels = append(user.BackfillQueue.reCheckChannels, reCheckChannel) user.BackfillQueue.reCheckChannels = append(user.BackfillQueue.reCheckChannels, reCheckChannel)
for { for {
req := user.BackfillQueue.GetNextBackfill(user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel) req := user.BackfillQueue.GetNextBackfill(ctx, user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel)
user.log.Infofln("Handling backfill request %s", req) log.Info().Any("backfill_request", req).Msg("Handling backfill request")
log := log.With().
Int("queue_id", req.QueueID).
Stringer("portal_jid", req.Portal.JID).
Logger()
ctx := log.WithContext(ctx)
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, *req.Portal) conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, req.Portal)
if conv == nil { if err != nil {
user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String()) log.Err(err).Msg("Failed to get conversation data for backfill request")
req.MarkDone() continue
} else if conv == nil {
log.Debug().Msg("Couldn't find conversation data for backfill request")
err = req.MarkDone(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill request as done after data was not found")
}
continue continue
} }
portal := user.GetPortalByJID(conv.PortalKey.JID) portal := user.GetPortalByJID(conv.PortalKey.JID)
// Update the client store with basic chat settings. // Update the client store with basic chat settings.
if conv.MuteEndTime.After(time.Now()) { if conv.MuteEndTime.After(time.Now()) {
user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime) err = user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
if err != nil {
log.Err(err).Msg("Failed to save muted until time from conversation data")
}
} }
if conv.Archived { if conv.Archived {
user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true) err = user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
if err != nil {
log.Err(err).Msg("Failed to save archived state from conversation data")
}
} }
if conv.Pinned > 0 { if conv.Pinned > 0 {
user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true) err = user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
if err != nil {
log.Err(err).Msg("Failed to save pinned state from conversation data")
}
} }
if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration { if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
log.Debug().
Uint32("old_time", portal.ExpirationTime).
Uint32("new_time", *conv.EphemeralExpiration).
Msg("Updating portal ephemeral expiration time")
portal.ExpirationTime = *conv.EphemeralExpiration portal.ExpirationTime = *conv.EphemeralExpiration
portal.Update(nil) err = portal.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal after updating expiration time")
}
} }
user.backfillInChunks(req, conv, portal) user.backfillInChunks(ctx, req, conv, portal)
req.MarkDone() err = req.MarkDone(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill request as done after backfilling")
}
} }
} }

View file

@ -93,7 +93,7 @@ func (prov *ProvisioningAPI) BridgeStatePing(w http.ResponseWriter, r *http.Requ
remote = remote.Fill(user) remote = remote.Fill(user)
resp.RemoteStates[remote.RemoteID] = remote resp.RemoteStates[remote.RemoteID] = remote
} }
user.log.Debugfln("Responding bridge state in bridge status endpoint: %+v", resp) user.zlog.Debug().Any("response_data", &resp).Msg("Responding bridge state in bridge status endpoint")
jsonResponse(w, http.StatusOK, &resp) jsonResponse(w, http.StatusOK, &resp)
if len(resp.RemoteStates) > 0 { if len(resp.RemoteStates) > 0 {
user.BridgeState.SetPrev(remote) user.BridgeState.SetPrev(remote)

View file

@ -29,6 +29,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/rs/zerolog"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@ -37,7 +38,6 @@ import (
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
@ -117,7 +117,10 @@ func fnSetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge") ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else { } else {
ce.Portal.RelayUserID = ce.User.MXID ce.Portal.RelayUserID = ce.User.MXID
ce.Portal.Update(nil) err := ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after setting relay user")
}
ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account") ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account")
} }
} }
@ -139,7 +142,10 @@ func fnUnsetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge") ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else { } else {
ce.Portal.RelayUserID = "" ce.Portal.RelayUserID = ""
ce.Portal.Update(nil) err := ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after clearing relay user")
}
ce.Reply("Messages from non-logged-in users will no longer be bridged in this room") ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
} }
} }
@ -246,7 +252,7 @@ func fnJoin(ce *WrappedCommandEvent) {
ce.Reply("Failed to join group: %v", err) ce.Reply("Failed to join group: %v", err)
return return
} }
ce.Log.Debugln("%s successfully joined group %s", ce.User.MXID, jid) ce.ZLog.Debug().Stringer("group_jid", jid).Msg("User successfully joined WhatsApp group with link")
ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid) ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
} else if strings.HasPrefix(ce.Args[0], whatsmeow.NewsletterLinkPrefix) { } else if strings.HasPrefix(ce.Args[0], whatsmeow.NewsletterLinkPrefix) {
info, err := ce.User.Client.GetNewsletterInfoWithInvite(ce.Args[0]) info, err := ce.User.Client.GetNewsletterInfoWithInvite(ce.Args[0])
@ -259,14 +265,14 @@ func fnJoin(ce *WrappedCommandEvent) {
ce.Reply("Failed to follow channel: %v", err) ce.Reply("Failed to follow channel: %v", err)
return return
} }
ce.Log.Debugln("%s successfully followed channel %s", ce.User.MXID, info.ID) ce.ZLog.Debug().Stringer("channel_jid", info.ID).Msg("User successfully followed WhatsApp channel with link")
ce.Reply("Successfully followed channel `%s`, the portal should be created momentarily", info.ID) ce.Reply("Successfully followed channel `%s`, the portal should be created momentarily", info.ID)
} else { } else {
ce.Reply("That doesn't look like a WhatsApp invite link") ce.Reply("That doesn't look like a WhatsApp invite link")
} }
} }
func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) { func tryDecryptEvent(ce *WrappedCommandEvent, evt *event.Event) (json.RawMessage, error) {
var data json.RawMessage var data json.RawMessage
if evt.Type != event.EventEncrypted { if evt.Type != event.EventEncrypted {
data = evt.Content.VeryRaw data = evt.Content.VeryRaw
@ -275,7 +281,7 @@ func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, e
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, err return nil, err
} }
decrypted, err := crypto.Decrypt(evt) decrypted, err := ce.Bridge.Crypto.Decrypt(ce.Ctx, evt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -311,11 +317,11 @@ var cmdAccept = &commands.FullHandler{
func fnAccept(ce *WrappedCommandEvent) { func fnAccept(ce *WrappedCommandEvent) {
if len(ce.ReplyTo) == 0 { if len(ce.ReplyTo) == 0 {
ce.Reply("You must reply to a group invite message when using this command.") ce.Reply("You must reply to a group invite message when using this command.")
} else if evt, err := ce.Portal.MainIntent().GetEvent(ce.RoomID, ce.ReplyTo); err != nil { } else if evt, err := ce.Portal.MainIntent().GetEvent(ce.Ctx, ce.RoomID, ce.ReplyTo); err != nil {
ce.Log.Errorln("Failed to get event %s to handle !wa accept command: %v", ce.ReplyTo, err) ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to get reply target event to handle !wa accept command")
ce.Reply("Failed to get reply event") ce.Reply("Failed to get reply event")
} else if rawContent, err := tryDecryptEvent(ce.Bridge.Crypto, evt); err != nil { } else if rawContent, err := tryDecryptEvent(ce, evt); err != nil {
ce.Log.Errorln("Failed to decrypt event %s to handle !wa accept command: %v", ce.ReplyTo, err) ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to decrypt reply target event to handle !wa accept command")
ce.Reply("Failed to decrypt reply event") ce.Reply("Failed to decrypt reply event")
} else if meta, err := parseInviteMeta(rawContent); err != nil || meta == nil { } else if meta, err := parseInviteMeta(rawContent); err != nil || meta == nil {
ce.Reply("That doesn't look like a group invite message.") ce.Reply("That doesn't look like a group invite message.")
@ -344,16 +350,16 @@ func fnCreate(ce *WrappedCommandEvent) {
return return
} }
members, err := ce.Bot.JoinedMembers(ce.RoomID) members, err := ce.Bot.JoinedMembers(ce.Ctx, ce.RoomID)
if err != nil { if err != nil {
ce.Reply("Failed to get room members: %v", err) ce.Reply("Failed to get room members: %v", err)
return return
} }
var roomNameEvent event.RoomNameEventContent var roomNameEvent event.RoomNameEventContent
err = ce.Bot.StateEvent(ce.RoomID, event.StateRoomName, "", &roomNameEvent) err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateRoomName, "", &roomNameEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) { if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.Log.Errorln("Failed to get room name to create group:", err) ce.ZLog.Err(err).Msg("Failed to get room name to create group")
ce.Reply("Failed to get room name") ce.Reply("Failed to get room name")
return return
} else if len(roomNameEvent.Name) == 0 { } else if len(roomNameEvent.Name) == 0 {
@ -362,15 +368,17 @@ func fnCreate(ce *WrappedCommandEvent) {
} }
var encryptionEvent event.EncryptionEventContent var encryptionEvent event.EncryptionEventContent
err = ce.Bot.StateEvent(ce.RoomID, event.StateEncryption, "", &encryptionEvent) err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateEncryption, "", &encryptionEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) { if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.ZLog.Err(err).Msg("Failed to get room encryption status to create group")
ce.Reply("Failed to get room encryption status") ce.Reply("Failed to get room encryption status")
return return
} }
var createEvent event.CreateEventContent var createEvent event.CreateEventContent
err = ce.Bot.StateEvent(ce.RoomID, event.StateCreate, "", &createEvent) err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateCreate, "", &createEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) { if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.ZLog.Err(err).Msg("Failed to get room create event to create group")
ce.Reply("Failed to get room create event") ce.Reply("Failed to get room create event")
return return
} }
@ -395,7 +403,11 @@ func fnCreate(ce *WrappedCommandEvent) {
// TODO check m.space.parent to create rooms directly in communities // TODO check m.space.parent to create rooms directly in communities
messageID := ce.User.Client.GenerateMessageID() messageID := ce.User.Client.GenerateMessageID()
ce.Log.Infofln("Creating group for %s with name %s and participants %+v (create key: %s)", ce.RoomID, roomNameEvent.Name, participants, messageID) ce.ZLog.Info().
Str("room_name", roomNameEvent.Name).
Any("participants", participants).
Str("create_key", messageID).
Msg("Creating WhatsApp group for Matrix room")
ce.User.createKeyDedup = messageID ce.User.createKeyDedup = messageID
resp, err := ce.User.Client.CreateGroup(whatsmeow.ReqCreateGroup{ resp, err := ce.User.Client.CreateGroup(whatsmeow.ReqCreateGroup{
CreateKey: messageID, CreateKey: messageID,
@ -409,21 +421,25 @@ func fnCreate(ce *WrappedCommandEvent) {
ce.Reply("Failed to create group: %v", err) ce.Reply("Failed to create group: %v", err)
return return
} }
ce.ZLog.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("group_jid", resp.JID.String())
})
portal := ce.User.GetPortalByJID(resp.JID) portal := ce.User.GetPortalByJID(resp.JID)
portal.roomCreateLock.Lock() portal.roomCreateLock.Lock()
defer portal.roomCreateLock.Unlock() defer portal.roomCreateLock.Unlock()
if len(portal.MXID) != 0 { if len(portal.MXID) != 0 {
portal.log.Warnln("Detected race condition in room creation") ce.ZLog.Warn().Msg("Detected race condition in room creation")
// TODO race condition, clean up the old room // TODO race condition, clean up the old room
} }
portal.MXID = ce.RoomID portal.MXID = ce.RoomID
portal.updateLogger()
portal.Name = roomNameEvent.Name portal.Name = roomNameEvent.Name
portal.IsParent = resp.IsParent portal.IsParent = resp.IsParent
portal.Encrypted = encryptionEvent.Algorithm == id.AlgorithmMegolmV1 portal.Encrypted = encryptionEvent.Algorithm == id.AlgorithmMegolmV1
if !portal.Encrypted && ce.Bridge.Config.Bridge.Encryption.Default { if !portal.Encrypted && ce.Bridge.Config.Bridge.Encryption.Default {
_, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent()) _, err = portal.MainIntent().SendStateEvent(ce.Ctx, portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent())
if err != nil { if err != nil {
portal.log.Warnln("Failed to enable encryption in room:", err) ce.ZLog.Err(err).Msg("Failed to enable encryption in room")
if errors.Is(err, mautrix.MForbidden) { if errors.Is(err, mautrix.MForbidden) {
ce.Reply("I don't seem to have permission to enable encryption in this room.") ce.Reply("I don't seem to have permission to enable encryption in this room.")
} else { } else {
@ -433,8 +449,11 @@ func fnCreate(ce *WrappedCommandEvent) {
portal.Encrypted = true portal.Encrypted = true
} }
portal.Update(nil) err = portal.Update(ce.Ctx)
portal.UpdateBridgeInfo() if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after creating group")
}
portal.UpdateBridgeInfo(ce.Ctx)
ce.User.createKeyDedup = "" ce.User.createKeyDedup = ""
ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID) ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID)
@ -512,7 +531,7 @@ func fnLogin(ce *WrappedCommandEvent) {
} }
} }
if qrEventID != "" { if qrEventID != "" {
_, _ = ce.Bot.RedactEvent(ce.RoomID, qrEventID) _, _ = ce.Bot.RedactEvent(ce.Ctx, ce.RoomID, qrEventID)
} }
} }
@ -529,9 +548,9 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
if len(prevEvent) != 0 { if len(prevEvent) != 0 {
content.SetEdit(prevEvent) content.SetEdit(prevEvent)
} }
resp, err := ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &content) resp, err := ce.Bot.SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, &content)
if err != nil { if err != nil {
user.log.Errorln("Failed to send edited QR code to user:", err) ce.ZLog.Err(err).Msg("Failed to send edited QR code to user")
} else if len(prevEvent) == 0 { } else if len(prevEvent) == 0 {
prevEvent = resp.EventID prevEvent = resp.EventID
} }
@ -541,16 +560,16 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
func (user *User) uploadQR(ce *WrappedCommandEvent, code string) (id.ContentURI, bool) { func (user *User) uploadQR(ce *WrappedCommandEvent, code string) (id.ContentURI, bool) {
qrCode, err := qrcode.Encode(code, qrcode.Low, 256) qrCode, err := qrcode.Encode(code, qrcode.Low, 256)
if err != nil { if err != nil {
user.log.Errorln("Failed to encode QR code:", err) ce.ZLog.Err(err).Msg("Failed to encode QR code")
ce.Reply("Failed to encode QR code: %v", err) ce.Reply("Failed to encode QR code: %v", err)
return id.ContentURI{}, false return id.ContentURI{}, false
} }
bot := user.bridge.AS.BotClient() bot := user.bridge.AS.BotClient()
resp, err := bot.UploadBytes(qrCode, "image/png") resp, err := bot.UploadBytes(ce.Ctx, qrCode, "image/png")
if err != nil { if err != nil {
user.log.Errorln("Failed to upload QR code:", err) ce.ZLog.Err(err).Msg("Failed to upload QR code")
ce.Reply("Failed to upload QR code: %v", err) ce.Reply("Failed to upload QR code: %v", err)
return id.ContentURI{}, false return id.ContentURI{}, false
} }
@ -578,14 +597,14 @@ func fnLogout(ce *WrappedCommandEvent) {
puppet.ClearCustomMXID() puppet.ClearCustomMXID()
err := ce.User.Client.Logout() err := ce.User.Client.Logout()
if err != nil { if err != nil {
ce.User.log.Warnln("Error while logging out:", err) ce.ZLog.Err(err).Msg("Unknown error while logging out")
ce.Reply("Unknown error while logging out: %v", err) ce.Reply("Unknown error while logging out: %v", err)
return return
} }
ce.User.Session = nil ce.User.Session = nil
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
ce.User.DeleteConnection() ce.User.DeleteConnection()
ce.User.DeleteSession() ce.User.DeleteSession(ce.Ctx)
ce.Reply("Logged out successfully.") ce.Reply("Logged out successfully.")
} }
@ -620,10 +639,13 @@ func fnTogglePresence(ce *WrappedCommandEvent) {
if ce.User.IsLoggedIn() { if ce.User.IsLoggedIn() {
err := ce.User.Client.SendPresence(newPresence) err := ce.User.Client.SendPresence(newPresence)
if err != nil { if err != nil {
ce.User.log.Warnln("Failed to set presence:", err) ce.ZLog.Err(err).Msg("Failed to send presence to WhatsApp")
} }
} }
customPuppet.Update() err := customPuppet.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save puppet after toggling presence")
}
} }
var cmdDeleteSession = &commands.FullHandler{ var cmdDeleteSession = &commands.FullHandler{
@ -642,7 +664,7 @@ func fnDeleteSession(ce *WrappedCommandEvent) {
} }
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
ce.User.DeleteConnection() ce.User.DeleteConnection()
ce.User.DeleteSession() ce.User.DeleteSession(ce.Ctx)
ce.Reply("Session information purged") ce.Reply("Session information purged")
} }
@ -716,19 +738,19 @@ func fnPing(ce *WrappedCommandEvent) {
} }
} }
func canDeletePortal(portal *Portal, userID id.UserID) bool { func canDeletePortal(ce *WrappedCommandEvent, portal *Portal) bool {
if len(portal.MXID) == 0 { if len(portal.MXID) == 0 {
return false return false
} }
members, err := portal.MainIntent().JoinedMembers(portal.MXID) members, err := portal.MainIntent().JoinedMembers(ce.Ctx, portal.MXID)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to get joined members to check if portal can be deleted by %s: %v", userID, err) ce.ZLog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to get joined members to check if portal can be deleted by user")
return false return false
} }
for otherUser := range members.Joined { for otherUser := range members.Joined {
_, isPuppet := portal.bridge.ParsePuppetMXID(otherUser) _, isPuppet := portal.bridge.ParsePuppetMXID(otherUser)
if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == userID { if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == ce.User.MXID {
continue continue
} }
user := portal.bridge.GetUserByMXID(otherUser) user := portal.bridge.GetUserByMXID(otherUser)
@ -750,14 +772,14 @@ var cmdDeletePortal = &commands.FullHandler{
} }
func fnDeletePortal(ce *WrappedCommandEvent) { func fnDeletePortal(ce *WrappedCommandEvent) {
if !ce.User.Admin && !canDeletePortal(ce.Portal, ce.User.MXID) { if !ce.User.Admin && !canDeletePortal(ce, ce.Portal) {
ce.Reply("Only bridge admins can delete portals with other Matrix users") ce.Reply("Only bridge admins can delete portals with other Matrix users")
return return
} }
ce.Portal.log.Infoln(ce.User.MXID, "requested deletion of portal.") ce.ZLog.Info().Msg("User requested deletion of current portal")
ce.Portal.Delete() ce.Portal.Delete(ce.Ctx)
ce.Portal.Cleanup(false) ce.Portal.Cleanup(ce.Ctx, false)
} }
var cmdDeleteAllPortals = &commands.FullHandler{ var cmdDeleteAllPortals = &commands.FullHandler{
@ -778,7 +800,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
} else { } else {
portalsToDelete = portals[:0] portalsToDelete = portals[:0]
for _, portal := range portals { for _, portal := range portals {
if canDeletePortal(portal, ce.User.MXID) { if canDeletePortal(ce, portal) {
portalsToDelete = append(portalsToDelete, portal) portalsToDelete = append(portalsToDelete, portal)
} }
} }
@ -790,7 +812,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
leave := func(portal *Portal) { leave := func(portal *Portal) {
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
_, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{ _, _ = portal.MainIntent().KickUser(ce.Ctx, portal.MXID, &mautrix.ReqKickUser{
Reason: "Deleting portal", Reason: "Deleting portal",
UserID: ce.User.MXID, UserID: ce.User.MXID,
}) })
@ -801,21 +823,21 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
intent := customPuppet.CustomIntent() intent := customPuppet.CustomIntent()
leave = func(portal *Portal) { leave = func(portal *Portal) {
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
_, _ = intent.LeaveRoom(portal.MXID) _, _ = intent.LeaveRoom(ce.Ctx, portal.MXID)
_, _ = intent.ForgetRoom(portal.MXID) _, _ = intent.ForgetRoom(ce.Ctx, portal.MXID)
} }
} }
} }
ce.Reply("Found %d portals, deleting...", len(portalsToDelete)) ce.Reply("Found %d portals, deleting...", len(portalsToDelete))
for _, portal := range portalsToDelete { for _, portal := range portalsToDelete {
portal.Delete() portal.Delete(ce.Ctx)
leave(portal) leave(portal)
} }
ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.") ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.")
go func() { go func() {
for _, portal := range portalsToDelete { for _, portal := range portalsToDelete {
portal.Cleanup(false) portal.Cleanup(ce.Ctx, false)
} }
ce.Reply("Finished background cleanup of deleted portal rooms.") ce.Reply("Finished background cleanup of deleted portal rooms.")
}() }()
@ -882,7 +904,7 @@ func fnList(ce *WrappedCommandEvent) {
} }
var err error var err error
page := 1 page := 1
max := 100 maxPerPage := 100
if len(ce.Args) > 1 { if len(ce.Args) > 1 {
page, err = strconv.Atoi(ce.Args[1]) page, err = strconv.Atoi(ce.Args[1])
if err != nil || page <= 0 { if err != nil || page <= 0 {
@ -891,11 +913,11 @@ func fnList(ce *WrappedCommandEvent) {
} }
} }
if len(ce.Args) > 2 { if len(ce.Args) > 2 {
max, err = strconv.Atoi(ce.Args[2]) maxPerPage, err = strconv.Atoi(ce.Args[2])
if err != nil || max <= 0 { if err != nil || maxPerPage <= 0 {
ce.Reply("\"%s\" isn't a valid number of items per page", ce.Args[2]) ce.Reply("\"%s\" isn't a valid number of items per page", ce.Args[2])
return return
} else if max > 400 { } else if maxPerPage > 400 {
ce.Reply("Warning: a high number of items per page may fail to send a reply") ce.Reply("Warning: a high number of items per page may fail to send a reply")
} }
} }
@ -924,8 +946,8 @@ func fnList(ce *WrappedCommandEvent) {
ce.Reply("No %s found", strings.ToLower(typeName)) ce.Reply("No %s found", strings.ToLower(typeName))
return return
} }
pages := int(math.Ceil(float64(len(result)) / float64(max))) pages := int(math.Ceil(float64(len(result)) / float64(maxPerPage)))
if (page-1)*max >= len(result) { if (page-1)*maxPerPage >= len(result) {
if pages == 1 { if pages == 1 {
ce.Reply("There is only 1 page of %s", strings.ToLower(typeName)) ce.Reply("There is only 1 page of %s", strings.ToLower(typeName))
} else { } else {
@ -933,11 +955,11 @@ func fnList(ce *WrappedCommandEvent) {
} }
return return
} }
lastIndex := page * max lastIndex := page * maxPerPage
if lastIndex > len(result) { if lastIndex > len(result) {
lastIndex = len(result) lastIndex = len(result)
} }
result = result[(page-1)*max : lastIndex] result = result[(page-1)*maxPerPage : lastIndex]
ce.Reply("### %s (page %d of %d)\n\n%s", typeName, page, pages, strings.Join(result, "\n")) ce.Reply("### %s (page %d of %d)\n\n%s", typeName, page, pages, strings.Join(result, "\n"))
} }
@ -1036,13 +1058,13 @@ func fnOpen(ce *WrappedCommandEvent) {
} }
jid = newsletterMetadata.ID jid = newsletterMetadata.ID
} }
ce.Log.Debugln("Importing", jid, "for", ce.User.MXID) ce.ZLog.Debug().Stringer("chat_jid", jid).Msg("Importing chat for user")
portal := ce.User.GetPortalByJID(jid) portal := ce.User.GetPortalByJID(jid)
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
portal.UpdateMatrixRoom(ce.User, groupInfo, newsletterMetadata) portal.UpdateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata)
ce.Reply("Portal room synced.") ce.Reply("Portal room synced.")
} else { } else {
err = portal.CreateMatrixRoom(ce.User, groupInfo, newsletterMetadata, true, true) err = portal.CreateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata, true, true)
if err != nil { if err != nil {
ce.Reply("Failed to create room: %v", err) ce.Reply("Failed to create room: %v", err)
} else { } else {
@ -1085,7 +1107,7 @@ func fnPM(ce *WrappedCommandEvent) {
return return
} }
portal, puppet, justCreated, err := user.StartPM(targetUser.JID, "manual PM command") portal, puppet, justCreated, err := user.StartPM(ce.Ctx, targetUser.JID, "manual PM command")
if err != nil { if err != nil {
ce.Reply("Failed to create portal room: %v", err) ce.Reply("Failed to create portal room: %v", err)
} else if !justCreated { } else if !justCreated {
@ -1154,11 +1176,16 @@ func fnSync(ce *WrappedCommandEvent) {
ce.Reply("Personal filtering spaces are not enabled on this instance of the bridge") ce.Reply("Personal filtering spaces are not enabled on this instance of the bridge")
return return
} }
keys := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.JID) keys, err := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.Ctx, ce.User.JID)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to get list of private chats not in space")
ce.Reply("Failed to get list of private chats not in space")
return
}
count := 0 count := 0
for _, key := range keys { for _, key := range keys {
portal := ce.Bridge.GetPortalByJID(key) portal := ce.Bridge.GetPortalByJID(key)
portal.addToPersonalSpace(ce.User) portal.addToPersonalSpace(ce.Ctx, ce.User)
count++ count++
} }
plural := "s" plural := "s"
@ -1208,6 +1235,9 @@ func fnDisappearingTimer(ce *WrappedCommandEvent) {
ce.Portal.ExpirationTime = prevExpirationTime ce.Portal.ExpirationTime = prevExpirationTime
return return
} }
ce.Portal.Update(nil) err = ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after setting disappearing timer")
}
ce.React("✅") ce.React("✅")
} }

View file

@ -17,6 +17,9 @@
package main package main
import ( import (
"context"
"fmt"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@ -24,8 +27,11 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error
puppet.CustomMXID = mxid puppet.CustomMXID = mxid
puppet.AccessToken = accessToken puppet.AccessToken = accessToken
puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence
puppet.Update() err := puppet.Update(context.TODO())
err := puppet.StartCustomMXID(false) if err != nil {
return fmt.Errorf("failed to save access token: %w", err)
}
err = puppet.StartCustomMXID(false)
if err != nil { if err != nil {
return err return err
} }
@ -45,12 +51,15 @@ func (puppet *Puppet) ClearCustomMXID() {
puppet.customIntent = nil puppet.customIntent = nil
puppet.customUser = nil puppet.customUser = nil
if save { if save {
puppet.Update() err := puppet.Update(context.TODO())
if err != nil {
puppet.zlog.Err(err).Msg("Failed to clear custom MXID")
}
} }
} }
func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error { func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail) newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(context.TODO(), puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
if err != nil { if err != nil {
puppet.ClearCustomMXID() puppet.ClearCustomMXID()
return err return err
@ -60,11 +69,11 @@ func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
puppet.bridge.puppetsLock.Unlock() puppet.bridge.puppetsLock.Unlock()
if puppet.AccessToken != newAccessToken { if puppet.AccessToken != newAccessToken {
puppet.AccessToken = newAccessToken puppet.AccessToken = newAccessToken
puppet.Update() err = puppet.Update(context.TODO())
} }
puppet.customIntent = newIntent puppet.customIntent = newIntent
puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID) puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID)
return nil return err
} }
func (user *User) tryAutomaticDoublePuppeting() { func (user *User) tryAutomaticDoublePuppeting() {

View file

@ -1,340 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type BackfillType int
const (
BackfillImmediate BackfillType = 0
BackfillForward BackfillType = 100
BackfillDeferred BackfillType = 200
)
func (bt BackfillType) String() string {
switch bt {
case BackfillImmediate:
return "IMMEDIATE"
case BackfillForward:
return "FORWARD"
case BackfillDeferred:
return "DEFERRED"
}
return "UNKNOWN"
}
type BackfillQuery struct {
db *Database
log log.Logger
backfillQueryLock sync.Mutex
}
func (bq *BackfillQuery) New() *Backfill {
return &Backfill{
db: bq.db,
log: bq.log,
Portal: &PortalKey{},
}
}
func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill {
return &Backfill{
db: bq.db,
log: bq.log,
UserID: userID,
BackfillType: backfillType,
Priority: priority,
Portal: portal,
TimeStart: timeStart,
MaxBatchEvents: maxBatchEvents,
MaxTotalEvents: maxTotalEvents,
BatchDelay: batchDelay,
}
}
const (
getNextBackfillQuery = `
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (
dispatch_time IS NULL
OR (
dispatch_time < $2
AND completed_at IS NULL
)
)
ORDER BY type, priority, queue_id
LIMIT 1
`
getUnstartedOrInFlightQuery = `
SELECT 1
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (dispatch_time IS NULL OR completed_at IS NULL)
LIMIT 1
`
)
// GetNext returns the next backfill to perform
func (bq *BackfillQuery) GetNext(userID id.UserID, backfillTypes []BackfillType) (backfill *Backfill) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
var types []string
for _, backfillType := range backfillTypes {
types = append(types, strconv.Itoa(int(backfillType)))
}
rows, err := bq.db.Query(fmt.Sprintf(getNextBackfillQuery, strings.Join(types, ",")), userID, time.Now().Add(-15*time.Minute))
if err != nil || rows == nil {
bq.log.Errorfln("Failed to query next backfill queue job: %v", err)
return
}
defer rows.Close()
if rows.Next() {
backfill = bq.New().Scan(rows)
}
return
}
func (bq *BackfillQuery) HasUnstartedOrInFlightOfType(userID id.UserID, backfillTypes []BackfillType) bool {
if len(backfillTypes) == 0 {
return false
}
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
types := []string{}
for _, backfillType := range backfillTypes {
types = append(types, strconv.Itoa(int(backfillType)))
}
rows, err := bq.db.Query(fmt.Sprintf(getUnstartedOrInFlightQuery, strings.Join(types, ",")), userID)
if err != nil || rows == nil {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
bq.log.Warnfln("Failed to query backfill queue jobs: %v", err)
}
// No rows means that there are no unstarted or in flight backfill
// requests.
return false
}
defer rows.Close()
return rows.Next()
}
func (bq *BackfillQuery) DeleteAll(userID id.UserID) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
_, err := bq.db.Exec("DELETE FROM backfill_queue WHERE user_mxid=$1", userID)
if err != nil {
bq.log.Warnfln("Failed to delete backfill queue items for %s: %v", userID, err)
}
}
func (bq *BackfillQuery) DeleteAllForPortal(userID id.UserID, portalKey PortalKey) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
_, err := bq.db.Exec(`
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`, userID, portalKey.JID, portalKey.Receiver)
if err != nil {
bq.log.Warnfln("Failed to delete backfill queue items for %s/%s: %v", userID, portalKey.JID, err)
}
}
type Backfill struct {
db *Database
log log.Logger
// Fields
QueueID int
UserID id.UserID
BackfillType BackfillType
Priority int
Portal *PortalKey
TimeStart *time.Time
MaxBatchEvents int
MaxTotalEvents int
BatchDelay int
DispatchTime *time.Time
CompletedAt *time.Time
}
func (b *Backfill) String() string {
return fmt.Sprintf("Backfill{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
)
}
func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
var maxTotalEvents, batchDelay sql.NullInt32
err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &maxTotalEvents, &batchDelay)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
b.log.Errorln("Database scan failed:", err)
}
return nil
}
b.MaxTotalEvents = int(maxTotalEvents.Int32)
b.BatchDelay = int(batchDelay.Int32)
return b
}
func (b *Backfill) Insert() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
rows, err := b.db.Query(`
INSERT INTO backfill_queue
(user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING queue_id
`, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt)
defer rows.Close()
if err != nil || !rows.Next() {
b.log.Warnfln("Failed to insert %v/%s with priority %d: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
return
}
err = rows.Scan(&b.QueueID)
if err != nil {
b.log.Warnfln("Failed to insert %s/%s with priority %s: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
}
}
func (b *Backfill) MarkDispatched() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
b.log.Errorfln("Cannot mark backfill as dispatched without queue_id. Maybe it wasn't actually inserted in the database?")
return
}
_, err := b.db.Exec("UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
if err != nil {
b.log.Warnfln("Failed to mark %s/%s as dispatched: %v", b.BackfillType, b.Priority, err)
}
}
func (b *Backfill) MarkDone() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
b.log.Errorfln("Cannot mark backfill done without queue_id. Maybe it wasn't actually inserted in the database?")
return
}
_, err := b.db.Exec("UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
if err != nil {
b.log.Warnfln("Failed to mark %s/%s as complete: %v", b.BackfillType, b.Priority, err)
}
}
func (bq *BackfillQuery) NewBackfillState(userID id.UserID, portalKey *PortalKey) *BackfillState {
return &BackfillState{
db: bq.db,
log: bq.log,
UserID: userID,
Portal: portalKey,
}
}
const (
getBackfillState = `
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
FROM backfill_state
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
)
type BackfillState struct {
db *Database
log log.Logger
// Fields
UserID id.UserID
Portal *PortalKey
ProcessingBatch bool
BackfillComplete bool
FirstExpectedTimestamp uint64
}
func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
b.log.Errorln("Database scan failed:", err)
}
return nil
}
return b
}
func (b *BackfillState) Upsert() {
_, err := b.db.Exec(`
INSERT INTO backfill_state
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
DO UPDATE SET
processing_batch=EXCLUDED.processing_batch,
backfill_complete=EXCLUDED.backfill_complete,
first_expected_ts=EXCLUDED.first_expected_ts`,
b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp)
if err != nil {
b.log.Warnfln("Failed to insert backfill state for %s: %v", b.Portal.JID, err)
}
}
func (b *BackfillState) SetProcessingBatch(processing bool) {
b.ProcessingBatch = processing
b.Upsert()
}
func (bq *BackfillQuery) GetBackfillState(userID id.UserID, portalKey *PortalKey) (backfillState *BackfillState) {
rows, err := bq.db.Query(getBackfillState, userID, portalKey.JID, portalKey.Receiver)
if err != nil || rows == nil {
bq.log.Error(err)
return
}
defer rows.Close()
if rows.Next() {
backfillState = bq.NewBackfillState(userID, portalKey).Scan(rows)
}
return
}

253
database/backfillqueue.go Normal file
View file

@ -0,0 +1,253 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
type BackfillType int
const (
BackfillImmediate BackfillType = 0
BackfillForward BackfillType = 100
BackfillDeferred BackfillType = 200
)
func (bt BackfillType) String() string {
switch bt {
case BackfillImmediate:
return "IMMEDIATE"
case BackfillForward:
return "FORWARD"
case BackfillDeferred:
return "DEFERRED"
}
return "UNKNOWN"
}
type BackfillTaskQuery struct {
*dbutil.QueryHelper[*BackfillTask]
//backfillQueryLock sync.Mutex
}
func newBackfillTask(qh *dbutil.QueryHelper[*BackfillTask]) *BackfillTask {
return &BackfillTask{qh: qh}
}
func (bq *BackfillTaskQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *BackfillTask {
return &BackfillTask{
qh: bq.QueryHelper,
UserID: userID,
BackfillType: backfillType,
Priority: priority,
Portal: portal,
TimeStart: timeStart,
MaxBatchEvents: maxBatchEvents,
MaxTotalEvents: maxTotalEvents,
BatchDelay: batchDelay,
}
}
const (
getNextBackfillTaskQueryTemplate = `
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (
dispatch_time IS NULL
OR (
dispatch_time < $2
AND completed_at IS NULL
)
)
ORDER BY type, priority, queue_id
LIMIT 1
`
getUnstartedOrInFlightBackfillTaskQueryTemplate = `
SELECT 1
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (dispatch_time IS NULL OR completed_at IS NULL)
LIMIT 1
`
deleteBackfillQueueForUserQuery = "DELETE FROM backfill_queue WHERE user_mxid=$1"
deleteBackfillQueueForPortalQuery = `
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
insertBackfillTaskQuery = `
INSERT INTO backfill_queue (
user_mxid, type, priority, portal_jid, portal_receiver, time_start,
max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING queue_id
`
markBackfillTaskDispatchedQuery = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"
markBackfillTaskDoneQuery = "UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2"
)
func typesToString(backfillTypes []BackfillType) string {
types := make([]string, len(backfillTypes))
for i, backfillType := range backfillTypes {
types[i] = strconv.Itoa(int(backfillType))
}
return strings.Join(types, ",")
}
// GetNext returns the next backfill to perform
func (bq *BackfillTaskQuery) GetNext(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (*BackfillTask, error) {
if len(backfillTypes) == 0 {
return nil, nil
}
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
query := fmt.Sprintf(getNextBackfillTaskQueryTemplate, typesToString(backfillTypes))
return bq.QueryOne(ctx, query, userID, time.Now().Add(-15*time.Minute))
}
func (bq *BackfillTaskQuery) HasUnstartedOrInFlightOfType(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (has bool) {
if len(backfillTypes) == 0 {
return false
}
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
query := fmt.Sprintf(getUnstartedOrInFlightBackfillTaskQueryTemplate, typesToString(backfillTypes))
err := bq.GetDB().QueryRow(ctx, query, userID).Scan(&has)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
zerolog.Ctx(ctx).Err(err).Msg("Failed to check if backfill queue has jobs")
}
return
}
func (bq *BackfillTaskQuery) DeleteAll(ctx context.Context, userID id.UserID) error {
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
return bq.Exec(ctx, deleteBackfillQueueForUserQuery, userID)
}
func (bq *BackfillTaskQuery) DeleteAllForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
return bq.Exec(ctx, deleteBackfillQueueForPortalQuery, userID, portalKey.JID, portalKey.Receiver)
}
type BackfillTask struct {
qh *dbutil.QueryHelper[*BackfillTask]
QueueID int
UserID id.UserID
BackfillType BackfillType
Priority int
Portal PortalKey
TimeStart *time.Time
MaxBatchEvents int
MaxTotalEvents int
BatchDelay int
DispatchTime *time.Time
CompletedAt *time.Time
}
func (b *BackfillTask) MarshalZerologObject(evt *zerolog.Event) {
evt.Int("queue_id", b.QueueID).
Stringer("user_id", b.UserID).
Stringer("backfill_type", b.BackfillType).
Int("priority", b.Priority).
Stringer("portal_jid", b.Portal.JID).
Any("time_start", b.TimeStart).
Int("max_batch_events", b.MaxBatchEvents).
Int("max_total_events", b.MaxTotalEvents).
Int("batch_delay", b.BatchDelay).
Any("dispatch_time", b.DispatchTime).
Any("completed_at", b.CompletedAt)
}
func (b *BackfillTask) String() string {
return fmt.Sprintf(
"BackfillTask{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
)
}
func (b *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) {
var maxTotalEvents, batchDelay sql.NullInt32
err := row.Scan(
&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart,
&b.MaxBatchEvents, &maxTotalEvents, &batchDelay,
)
if err != nil {
return nil, err
}
b.MaxTotalEvents = int(maxTotalEvents.Int32)
b.BatchDelay = int(batchDelay.Int32)
return b, nil
}
func (b *BackfillTask) sqlVariables() []any {
return []any{
b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart,
b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt,
}
}
func (b *BackfillTask) Insert(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
return b.qh.GetDB().QueryRow(ctx, insertBackfillTaskQuery, b.sqlVariables()...).Scan(&b.QueueID)
}
func (b *BackfillTask) MarkDispatched(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
}
return b.qh.Exec(ctx, markBackfillTaskDispatchedQuery, time.Now(), b.QueueID)
}
func (b *BackfillTask) MarkDone(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
}
return b.qh.Exec(ctx, markBackfillTaskDoneQuery, time.Now(), b.QueueID)
}

94
database/backfillstate.go Normal file
View file

@ -0,0 +1,94 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
type BackfillStateQuery struct {
*dbutil.QueryHelper[*BackfillState]
}
func newBackfillState(qh *dbutil.QueryHelper[*BackfillState]) *BackfillState {
return &BackfillState{qh: qh}
}
func (bq *BackfillStateQuery) NewBackfillState(userID id.UserID, portalKey PortalKey) *BackfillState {
return &BackfillState{
qh: bq.QueryHelper,
UserID: userID,
Portal: portalKey,
}
}
const (
getBackfillStateQuery = `
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
FROM backfill_state
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
upsertBackfillStateQuery = `
INSERT INTO backfill_state
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
DO UPDATE SET
processing_batch=EXCLUDED.processing_batch,
backfill_complete=EXCLUDED.backfill_complete,
first_expected_ts=EXCLUDED.first_expected_ts
`
)
func (bq *BackfillStateQuery) GetBackfillState(ctx context.Context, userID id.UserID, portalKey PortalKey) (*BackfillState, error) {
return bq.QueryOne(ctx, getBackfillStateQuery, userID, portalKey.JID, portalKey.Receiver)
}
type BackfillState struct {
qh *dbutil.QueryHelper[*BackfillState]
UserID id.UserID
Portal PortalKey
ProcessingBatch bool
BackfillComplete bool
FirstExpectedTimestamp uint64
}
func (b *BackfillState) Scan(row dbutil.Scannable) (*BackfillState, error) {
return dbutil.ValueOrErr(b, row.Scan(
&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp,
))
}
func (b *BackfillState) sqlVariables() []any {
return []any{b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp}
}
func (b *BackfillState) Upsert(ctx context.Context) error {
return b.qh.Exec(ctx, upsertBackfillStateQuery, b.sqlVariables()...)
}
func (b *BackfillState) SetProcessingBatch(ctx context.Context, processing bool) error {
b.ProcessingBatch = processing
return b.Upsert(ctx)
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -26,7 +26,6 @@ import (
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore" "go.mau.fi/whatsmeow/store/sqlstore"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-whatsapp/database/upgrades" "maunium.net/go/mautrix-whatsapp/database/upgrades"
) )
@ -45,51 +44,28 @@ type Database struct {
Reaction *ReactionQuery Reaction *ReactionQuery
DisappearingMessage *DisappearingMessageQuery DisappearingMessage *DisappearingMessageQuery
Backfill *BackfillQuery BackfillQueue *BackfillTaskQuery
BackfillState *BackfillStateQuery
HistorySync *HistorySyncQuery HistorySync *HistorySyncQuery
MediaBackfillRequest *MediaBackfillRequestQuery MediaBackfillRequest *MediaBackfillRequestQuery
} }
func New(baseDB *dbutil.Database, log maulogger.Logger) *Database { func New(db *dbutil.Database) *Database {
db := &Database{Database: baseDB}
db.UpgradeTable = upgrades.Table db.UpgradeTable = upgrades.Table
db.User = &UserQuery{ return &Database{
db: db, Database: db,
log: log.Sub("User"), User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
DisappearingMessage: &DisappearingMessageQuery{dbutil.MakeQueryHelper(db, newDisappearingMessage)},
BackfillQueue: &BackfillTaskQuery{dbutil.MakeQueryHelper(db, newBackfillTask)},
BackfillState: &BackfillStateQuery{dbutil.MakeQueryHelper(db, newBackfillState)},
HistorySync: &HistorySyncQuery{dbutil.MakeQueryHelper(db, newHistorySyncConversation)},
MediaBackfillRequest: &MediaBackfillRequestQuery{dbutil.MakeQueryHelper(db, newMediaBackfillRequest)},
} }
db.Portal = &PortalQuery{
db: db,
log: log.Sub("Portal"),
}
db.Puppet = &PuppetQuery{
db: db,
log: log.Sub("Puppet"),
}
db.Message = &MessageQuery{
db: db,
log: log.Sub("Message"),
}
db.Reaction = &ReactionQuery{
db: db,
log: log.Sub("Reaction"),
}
db.DisappearingMessage = &DisappearingMessageQuery{
db: db,
log: log.Sub("DisappearingMessage"),
}
db.Backfill = &BackfillQuery{
db: db,
log: log.Sub("Backfill"),
}
db.HistorySync = &HistorySyncQuery{
db: db,
log: log.Sub("HistorySync"),
}
db.MediaBackfillRequest = &MediaBackfillRequestQuery{
db: db,
log: log.Sub("MediaBackfillRequest"),
}
return db
} }
func isRetryableError(err error) bool { func isRetryableError(err error) bool {

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,32 +17,29 @@
package database package database
import ( import (
"context"
"database/sql" "database/sql"
"errors"
"time" "time"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
) )
type DisappearingMessageQuery struct { type DisappearingMessageQuery struct {
db *Database *dbutil.QueryHelper[*DisappearingMessage]
log log.Logger
} }
func (dmq *DisappearingMessageQuery) New() *DisappearingMessage { func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage {
return &DisappearingMessage{ return &DisappearingMessage{
db: dmq.db, qh: qh,
log: dmq.log,
} }
} }
func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage { func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage {
dm := &DisappearingMessage{ dm := &DisappearingMessage{
db: dmq.db, qh: dmq.QueryHelper,
log: dmq.log,
RoomID: roomID, RoomID: roomID,
EventID: eventID, EventID: eventID,
ExpireIn: expireIn, ExpireIn: expireIn,
@ -55,22 +52,17 @@ const (
getAllScheduledDisappearingMessagesQuery = ` getAllScheduledDisappearingMessagesQuery = `
SELECT room_id, event_id, expire_in, expire_at FROM disappearing_message WHERE expire_at IS NOT NULL AND expire_at <= $1 SELECT room_id, event_id, expire_in, expire_at FROM disappearing_message WHERE expire_at IS NOT NULL AND expire_at <= $1
` `
insertDisappearingMessageQuery = `INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`
updateDisappearingMessageExpiryQuery = "UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3"
deleteDisappearingMessageQuery = "DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2"
) )
func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(duration time.Duration) (messages []*DisappearingMessage) { func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) {
rows, err := dmq.db.Query(getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli()) return dmq.QueryMany(ctx, getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli())
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, dmq.New().Scan(rows))
}
return
} }
type DisappearingMessage struct { type DisappearingMessage struct {
db *Database qh *dbutil.QueryHelper[*DisappearingMessage]
log log.Logger
RoomID id.RoomID RoomID id.RoomID
EventID id.EventID EventID id.EventID
@ -78,50 +70,33 @@ type DisappearingMessage struct {
ExpireAt time.Time ExpireAt time.Time
} }
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage { func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
var expireIn int64 var expireIn int64
var expireAt sql.NullInt64 var expireAt sql.NullInt64
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt) err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { return nil, err
msg.log.Errorln("Database scan failed:", err)
}
return nil
} }
msg.ExpireIn = time.Duration(expireIn) * time.Millisecond msg.ExpireIn = time.Duration(expireIn) * time.Millisecond
if expireAt.Valid { if expireAt.Valid {
msg.ExpireAt = time.UnixMilli(expireAt.Int64) msg.ExpireAt = time.UnixMilli(expireAt.Int64)
} }
return msg return msg, nil
} }
func (msg *DisappearingMessage) Insert(txn dbutil.Execable) { func (msg *DisappearingMessage) sqlVariables() []any {
if txn == nil { return []any{msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), dbutil.UnixMilliPtr(msg.ExpireAt)}
txn = msg.db
}
var expireAt sql.NullInt64
if !msg.ExpireAt.IsZero() {
expireAt.Valid = true
expireAt.Int64 = msg.ExpireAt.UnixMilli()
}
_, err := txn.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`,
msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), expireAt)
if err != nil {
msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err)
}
} }
func (msg *DisappearingMessage) StartTimer() { func (msg *DisappearingMessage) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...)
}
func (msg *DisappearingMessage) StartTimer(ctx context.Context) error {
msg.ExpireAt = time.Now().Add(msg.ExpireIn * time.Second) msg.ExpireAt = time.Now().Add(msg.ExpireIn * time.Second)
_, err := msg.db.Exec("UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3", msg.ExpireAt.Unix(), msg.RoomID, msg.EventID) return msg.qh.Exec(ctx, updateDisappearingMessageExpiryQuery, msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to update %s/%s: %v", msg.RoomID, msg.EventID, err)
}
} }
func (msg *DisappearingMessage) Delete() { func (msg *DisappearingMessage) Delete(ctx context.Context) error {
_, err := msg.db.Exec("DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2", msg.RoomID, msg.EventID) return msg.qh.Exec(ctx, deleteDisappearingMessageQuery, msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to delete %s/%s: %v", msg.RoomID, msg.EventID, err)
}
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan, Sumner Evans // Copyright (C) 2024 Tulir Asokan, Sumner Evans
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,8 +17,7 @@
package database package database
import ( import (
"database/sql" "context"
"errors"
"fmt" "fmt"
"time" "time"
@ -26,23 +25,19 @@ import (
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
waProto "go.mau.fi/whatsmeow/binary/proto" waProto "go.mau.fi/whatsmeow/binary/proto"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
type HistorySyncQuery struct { type HistorySyncQuery struct {
db *Database *dbutil.QueryHelper[*HistorySyncConversation]
log log.Logger
} }
type HistorySyncConversation struct { type HistorySyncConversation struct {
db *Database qh *dbutil.QueryHelper[*HistorySyncConversation]
log log.Logger
UserID id.UserID UserID id.UserID
ConversationID string ConversationID string
PortalKey *PortalKey PortalKey PortalKey
LastMessageTimestamp time.Time LastMessageTimestamp time.Time
MuteEndTime time.Time MuteEndTime time.Time
Archived bool Archived bool
@ -54,18 +49,16 @@ type HistorySyncConversation struct {
UnreadCount uint32 UnreadCount uint32
} }
func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation { func newHistorySyncConversation(qh *dbutil.QueryHelper[*HistorySyncConversation]) *HistorySyncConversation {
return &HistorySyncConversation{ return &HistorySyncConversation{
db: hsq.db, qh: qh,
log: hsq.log,
PortalKey: &PortalKey{},
} }
} }
func (hsq *HistorySyncQuery) NewConversationWithValues( func (hsq *HistorySyncQuery) NewConversationWithValues(
userID id.UserID, userID id.UserID,
conversationID string, conversationID string,
portalKey *PortalKey, portalKey PortalKey,
lastMessageTimestamp, lastMessageTimestamp,
muteEndTime uint64, muteEndTime uint64,
archived bool, archived bool,
@ -74,10 +67,10 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
endOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType, endOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType,
ephemeralExpiration *uint32, ephemeralExpiration *uint32,
markedAsUnread bool, markedAsUnread bool,
unreadCount uint32) *HistorySyncConversation { unreadCount uint32,
) *HistorySyncConversation {
return &HistorySyncConversation{ return &HistorySyncConversation{
db: hsq.db, qh: hsq.QueryHelper,
log: hsq.log,
UserID: userID, UserID: userID,
ConversationID: conversationID, ConversationID: conversationID,
PortalKey: portalKey, PortalKey: portalKey,
@ -94,6 +87,17 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
} }
const ( const (
upsertHistorySyncConversationQuery = `
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (user_mxid, conversation_id)
DO UPDATE SET
last_message_timestamp=CASE
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
ELSE history_sync_conversation.last_message_timestamp
END,
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
`
getNMostRecentConversations = ` getNMostRecentConversations = `
SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
FROM history_sync_conversation FROM history_sync_conversation
@ -108,24 +112,19 @@ const (
AND portal_jid=$2 AND portal_jid=$2
AND portal_receiver=$3 AND portal_receiver=$3
` `
deleteAllConversationsQuery = "DELETE FROM history_sync_conversation WHERE user_mxid=$1"
deleteHistorySyncConversationQuery = `
DELETE FROM history_sync_conversation
WHERE user_mxid=$1 AND conversation_id=$2
`
) )
func (hsc *HistorySyncConversation) Upsert() { func (hsc *HistorySyncConversation) sqlVariables() []any {
_, err := hsc.db.Exec(` return []any{
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (user_mxid, conversation_id)
DO UPDATE SET
last_message_timestamp=CASE
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
ELSE history_sync_conversation.last_message_timestamp
END,
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
`,
hsc.UserID, hsc.UserID,
hsc.ConversationID, hsc.ConversationID,
hsc.PortalKey.JID.String(), hsc.PortalKey.JID,
hsc.PortalKey.Receiver.String(), hsc.PortalKey.Receiver,
hsc.LastMessageTimestamp, hsc.LastMessageTimestamp,
hsc.Archived, hsc.Archived,
hsc.Pinned, hsc.Pinned,
@ -134,14 +133,16 @@ func (hsc *HistorySyncConversation) Upsert() {
hsc.EndOfHistoryTransferType, hsc.EndOfHistoryTransferType,
hsc.EphemeralExpiration, hsc.EphemeralExpiration,
hsc.MarkedAsUnread, hsc.MarkedAsUnread,
hsc.UnreadCount) hsc.UnreadCount,
if err != nil {
hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err)
} }
} }
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation { func (hsc *HistorySyncConversation) Upsert(ctx context.Context) error {
err := row.Scan( return hsc.qh.Exec(ctx, upsertHistorySyncConversationQuery, hsc.sqlVariables()...)
}
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) (*HistorySyncConversation, error) {
return dbutil.ValueOrErr(hsc, row.Scan(
&hsc.UserID, &hsc.UserID,
&hsc.ConversationID, &hsc.ConversationID,
&hsc.PortalKey.JID, &hsc.PortalKey.JID,
@ -154,69 +155,59 @@ func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConve
&hsc.EndOfHistoryTransferType, &hsc.EndOfHistoryTransferType,
&hsc.EphemeralExpiration, &hsc.EphemeralExpiration,
&hsc.MarkedAsUnread, &hsc.MarkedAsUnread,
&hsc.UnreadCount) &hsc.UnreadCount,
if err != nil { ))
if !errors.Is(err, sql.ErrNoRows) {
hsc.log.Errorln("Database scan failed:", err)
}
return nil
}
return hsc
} }
func (hsq *HistorySyncQuery) GetRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) { func (hsq *HistorySyncQuery) GetRecentConversations(ctx context.Context, userID id.UserID, n int) ([]*HistorySyncConversation, error) {
nPtr := &n nPtr := &n
// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit. // Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
if n < 0 && hsq.db.Dialect == dbutil.Postgres { if n < 0 && hsq.GetDB().Dialect == dbutil.Postgres {
nPtr = nil nPtr = nil
} }
rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr) return hsq.QueryMany(ctx, getNMostRecentConversations, userID, nPtr)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
for rows.Next() {
conversations = append(conversations, hsq.NewConversation().Scan(rows))
}
return
} }
func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey PortalKey) (conversation *HistorySyncConversation) { func (hsq *HistorySyncQuery) GetConversation(ctx context.Context, userID id.UserID, portalKey PortalKey) (*HistorySyncConversation, error) {
rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver) return hsq.QueryOne(ctx, getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
if rows.Next() {
conversation = hsq.NewConversation().Scan(rows)
}
return
} }
func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) { func (hsq *HistorySyncQuery) DeleteAllConversations(ctx context.Context, userID id.UserID) error {
_, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID) return hsq.Exec(ctx, deleteAllConversationsQuery, userID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical chat info for %s/%s: %v", userID, err)
}
} }
const ( const (
getMessagesBetween = ` insertHistorySyncMessageQuery = `
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
`
getHistorySyncMessagesBetweenQueryTemplate = `
SELECT data FROM history_sync_message SELECT data FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2 WHERE user_mxid=$1 AND conversation_id=$2
%s %s
ORDER BY timestamp DESC ORDER BY timestamp DESC
%s %s
` `
deleteMessagesBetweenExclusive = ` deleteHistorySyncMessagesBetweenExclusiveQuery = `
DELETE FROM history_sync_message DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2 AND timestamp<$3 AND timestamp>$4 WHERE user_mxid=$1 AND conversation_id=$2 AND timestamp<$3 AND timestamp>$4
` `
deleteAllHistorySyncMessagesQuery = "DELETE FROM history_sync_message WHERE user_mxid=$1"
deleteHistorySyncMessagesForPortalQuery = `
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
`
conversationHasHistorySyncMessagesQuery = `
SELECT EXISTS(
SELECT 1 FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
)
`
) )
type HistorySyncMessage struct { type HistorySyncMessage struct {
db *Database hsq *HistorySyncQuery
log log.Logger
UserID id.UserID UserID id.UserID
ConversationID string ConversationID string
@ -231,8 +222,8 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
return nil, err return nil, err
} }
return &HistorySyncMessage{ return &HistorySyncMessage{
db: hsq.db, hsq: hsq,
log: hsq.log,
UserID: userID, UserID: userID,
ConversationID: conversationID, ConversationID: conversationID,
MessageID: messageID, MessageID: messageID,
@ -241,18 +232,27 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
}, nil }, nil
} }
func (hsm *HistorySyncMessage) Insert() error { func (hsm *HistorySyncMessage) Insert(ctx context.Context) error {
_, err := hsm.db.Exec(` return hsm.hsq.Exec(ctx, insertHistorySyncMessageQuery, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
return err
} }
func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) { func scanWebMessageInfo(rows dbutil.Scannable) (*waProto.WebMessageInfo, error) {
var msgData []byte
err := rows.Scan(&msgData)
if err != nil {
return nil, err
}
var historySyncMsg waProto.HistorySyncMsg
err = proto.Unmarshal(msgData, &historySyncMsg)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal message: %w", err)
}
return historySyncMsg.GetMessage(), nil
}
func (hsq *HistorySyncQuery) GetMessagesBetween(ctx context.Context, userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) ([]*waProto.WebMessageInfo, error) {
whereClauses := "" whereClauses := ""
args := []interface{}{userID, conversationID} args := []any{userID, conversationID}
argNum := 3 argNum := 3
if startTime != nil { if startTime != nil {
whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum) whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum)
@ -268,80 +268,35 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
if limit > 0 { if limit > 0 {
limitClause = fmt.Sprintf("LIMIT %d", limit) limitClause = fmt.Sprintf("LIMIT %d", limit)
} }
query := fmt.Sprintf(getHistorySyncMessagesBetweenQueryTemplate, whereClauses, limitClause)
rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...) return dbutil.ConvertRowFn[*waProto.WebMessageInfo](scanWebMessageInfo).
defer rows.Close() NewRowIter(hsq.GetDB().Query(ctx, query, args...)).
if err != nil || rows == nil { AsList()
if err != nil && !errors.Is(err, sql.ErrNoRows) {
hsq.log.Warnfln("Failed to query messages between range: %v", err)
}
return nil
}
var msgData []byte
for rows.Next() {
err = rows.Scan(&msgData)
if err != nil {
hsq.log.Errorfln("Database scan failed: %v", err)
continue
}
var historySyncMsg waProto.HistorySyncMsg
err = proto.Unmarshal(msgData, &historySyncMsg)
if err != nil {
hsq.log.Errorfln("Failed to unmarshal history sync message: %v", err)
continue
}
messages = append(messages, historySyncMsg.Message)
}
return
} }
func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error { func (hsq *HistorySyncQuery) DeleteMessages(ctx context.Context, userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
newest := messages[0] newest := messages[0]
beforeTS := time.Unix(int64(newest.GetMessageTimestamp())+1, 0) beforeTS := time.Unix(int64(newest.GetMessageTimestamp())+1, 0)
oldest := messages[len(messages)-1] oldest := messages[len(messages)-1]
afterTS := time.Unix(int64(oldest.GetMessageTimestamp())-1, 0) afterTS := time.Unix(int64(oldest.GetMessageTimestamp())-1, 0)
_, err := hsq.db.Exec(deleteMessagesBetweenExclusive, userID, conversationID, beforeTS, afterTS) return hsq.Exec(ctx, deleteHistorySyncMessagesBetweenExclusiveQuery, userID, conversationID, beforeTS, afterTS)
return err
} }
func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) { func (hsq *HistorySyncQuery) DeleteAllMessages(ctx context.Context, userID id.UserID) error {
_, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID) return hsq.Exec(ctx, deleteAllHistorySyncMessagesQuery, userID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s: %v", userID, err)
}
} }
func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(userID id.UserID, portalKey PortalKey) { func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
_, err := hsq.db.Exec(` return hsq.Exec(ctx, deleteHistorySyncMessagesForPortalQuery, userID, portalKey.JID)
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
`, userID, portalKey.JID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, portalKey.JID, err)
}
} }
func (hsq *HistorySyncQuery) ConversationHasMessages(userID id.UserID, portalKey PortalKey) (exists bool) { func (hsq *HistorySyncQuery) ConversationHasMessages(ctx context.Context, userID id.UserID, portalKey PortalKey) (exists bool, err error) {
err := hsq.db.QueryRow(` err = hsq.GetDB().QueryRow(ctx, conversationHasHistorySyncMessagesQuery, userID, portalKey.JID).Scan(&exists)
SELECT EXISTS(
SELECT 1 FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
)
`, userID, portalKey.JID).Scan(&exists)
if err != nil {
hsq.log.Warnfln("Failed to check if any messages are stored for %s/%s: %v", userID, portalKey.JID, err)
}
return return
} }
func (hsq *HistorySyncQuery) DeleteConversation(userID id.UserID, jid string) { func (hsq *HistorySyncQuery) DeleteConversation(ctx context.Context, userID id.UserID, jid string) error {
// This will also clear history_sync_message as there's a foreign key constraint // This will also clear history_sync_message as there's a foreign key constraint
_, err := hsq.db.Exec(` return hsq.Exec(ctx, deleteHistorySyncConversationQuery, userID, jid)
DELETE FROM history_sync_conversation
WHERE user_mxid=$1 AND conversation_id=$2
`, userID, jid)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, jid, err)
}
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan, Sumner Evans // Copyright (C) 2024 Tulir Asokan, Sumner Evans
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,14 +17,12 @@
package database package database
import ( import (
"database/sql" "context"
"errors"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
) )
type MediaBackfillRequestStatus int type MediaBackfillRequestStatus int
@ -36,34 +34,46 @@ const (
) )
type MediaBackfillRequestQuery struct { type MediaBackfillRequestQuery struct {
db *Database *dbutil.QueryHelper[*MediaBackfillRequest]
log log.Logger
} }
type MediaBackfillRequest struct { const (
db *Database getAllMediaBackfillRequestsForUserQuery = `
log log.Logger SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
FROM media_backfill_requests
WHERE user_mxid=$1
AND status=0
`
deleteAllMediaBackfillRequestsForUserQuery = "DELETE FROM media_backfill_requests WHERE user_mxid=$1"
upsertMediaBackfillRequestQuery = `
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id)
DO UPDATE SET
media_key=excluded.media_key,
status=excluded.status,
error=excluded.error
`
)
UserID id.UserID func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(ctx context.Context, userID id.UserID) ([]*MediaBackfillRequest, error) {
PortalKey *PortalKey return mbrq.QueryMany(ctx, getAllMediaBackfillRequestsForUserQuery, userID)
EventID id.EventID
MediaKey []byte
Status MediaBackfillRequestStatus
Error string
} }
func (mbrq *MediaBackfillRequestQuery) newMediaBackfillRequest() *MediaBackfillRequest { func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(ctx context.Context, userID id.UserID) error {
return mbrq.Exec(ctx, deleteAllMediaBackfillRequestsForUserQuery, userID)
}
func newMediaBackfillRequest(qh *dbutil.QueryHelper[*MediaBackfillRequest]) *MediaBackfillRequest {
return &MediaBackfillRequest{ return &MediaBackfillRequest{
db: mbrq.db, qh: qh,
log: mbrq.log,
PortalKey: &PortalKey{},
} }
} }
func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest { func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest {
return &MediaBackfillRequest{ return &MediaBackfillRequest{
db: mbrq.db, qh: mbrq.QueryHelper,
log: mbrq.log,
UserID: userID, UserID: userID,
PortalKey: portalKey, PortalKey: portalKey,
EventID: eventID, EventID: eventID,
@ -72,62 +82,25 @@ func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID
} }
} }
const ( type MediaBackfillRequest struct {
getMediaBackfillRequestsForUser = ` qh *dbutil.QueryHelper[*MediaBackfillRequest]
SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
FROM media_backfill_requests
WHERE user_mxid=$1
AND status=0
`
)
func (mbr *MediaBackfillRequest) Upsert() { UserID id.UserID
_, err := mbr.db.Exec(` PortalKey PortalKey
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error) EventID id.EventID
VALUES ($1, $2, $3, $4, $5, $6, $7) MediaKey []byte
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id) Status MediaBackfillRequestStatus
DO UPDATE SET Error string
media_key=EXCLUDED.media_key,
status=EXCLUDED.status,
error=EXCLUDED.error`,
mbr.UserID,
mbr.PortalKey.JID.String(),
mbr.PortalKey.Receiver.String(),
mbr.EventID,
mbr.MediaKey,
mbr.Status,
mbr.Error)
if err != nil {
mbr.log.Warnfln("Failed to insert media backfill request %s/%s/%s: %v", mbr.UserID, mbr.PortalKey.String(), mbr.EventID, err)
}
} }
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest { func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) (*MediaBackfillRequest, error) {
err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error) return dbutil.ValueOrErr(mbr, row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error))
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
mbr.log.Errorln("Database scan failed:", err)
}
return nil
}
return mbr
} }
func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(userID id.UserID) (requests []*MediaBackfillRequest) { func (mbr *MediaBackfillRequest) sqlVariables() []any {
rows, err := mbrq.db.Query(getMediaBackfillRequestsForUser, userID) return []any{mbr.UserID, mbr.PortalKey.JID, mbr.PortalKey.Receiver, mbr.EventID, mbr.MediaKey, mbr.Status, mbr.Error}
defer rows.Close()
if err != nil || rows == nil {
return nil
}
for rows.Next() {
requests = append(requests, mbrq.newMediaBackfillRequest().Scan(rows))
}
return
} }
func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(userID id.UserID) { func (mbr *MediaBackfillRequest) Upsert(ctx context.Context) error {
_, err := mbrq.db.Exec("DELETE FROM media_backfill_requests WHERE user_mxid=$1", userID) return mbr.qh.Exec(ctx, upsertMediaBackfillRequestQuery, mbr.sqlVariables()...)
if err != nil {
mbrq.log.Warnfln("Failed to delete media backfill requests for %s: %v", userID, err)
}
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,29 +17,22 @@
package database package database
import ( import (
"database/sql" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
type MessageQuery struct { type MessageQuery struct {
db *Database *dbutil.QueryHelper[*Message]
log log.Logger
} }
func (mq *MessageQuery) New() *Message { func newMessage(qh *dbutil.QueryHelper[*Message]) *Message {
return &Message{ return &Message{qh: qh}
db: mq.db,
log: mq.log,
}
} }
const ( const (
@ -67,60 +60,47 @@ const (
SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
` `
insertMessageQuery = `
INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
markMessageSentQuery = "UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4"
updateMessageMXIDQuery = "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
deleteMessageQuery = "DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3"
) )
func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { func (mq *MessageQuery) GetAll(ctx context.Context, chat PortalKey) ([]*Message, error) {
rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver) return mq.QueryMany(ctx, getAllMessagesQuery, chat.JID, chat.Receiver)
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return
} }
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message { func (mq *MessageQuery) GetByJID(ctx context.Context, chat PortalKey, jid types.MessageID) (*Message, error) {
return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid)) return mq.QueryOne(ctx, getMessageByJIDQuery, chat.JID, chat.Receiver, jid)
} }
func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message { func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid)) return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
} }
func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message { func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat PortalKey) (*Message, error) {
return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second)) return mq.GetLastInChatBefore(ctx, chat, time.Now().Add(60*time.Second))
} }
func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message { func (mq *MessageQuery) GetLastInChatBefore(ctx context.Context, chat PortalKey, maxTimestamp time.Time) (*Message, error) {
msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix())) msg, err := mq.QueryOne(ctx, getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix())
if msg == nil || msg.Timestamp.IsZero() { if msg != nil && msg.Timestamp.IsZero() {
// Old db, we don't know what the last message is. // Old db, we don't know what the last message is.
return nil msg = nil
} }
return msg return msg, err
} }
func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message { func (mq *MessageQuery) GetFirstInChat(ctx context.Context, chat PortalKey) (*Message, error) {
return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver)) return mq.QueryOne(ctx, getFirstMessageInChatQuery, chat.JID, chat.Receiver)
} }
func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) { func (mq *MessageQuery) GetMessagesBetween(ctx context.Context, chat PortalKey, minTimestamp, maxTimestamp time.Time) ([]*Message, error) {
rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix()) return mq.QueryMany(ctx, getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return
}
func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
if row == nil {
return nil
}
return mq.New().Scan(row)
} }
type MessageErrorType string type MessageErrorType string
@ -144,8 +124,7 @@ const (
) )
type Message struct { type Message struct {
db *Database qh *dbutil.QueryHelper[*Message]
log log.Logger
Chat PortalKey Chat PortalKey
JID types.MessageID JID types.MessageID
@ -172,76 +151,49 @@ func (msg *Message) IsFakeJID() bool {
const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s" const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s"
func (msg *Message) Scan(row dbutil.Scannable) *Message { func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
var ts int64 var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID) err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { return nil, err
msg.log.Errorln("Database scan failed:", err)
}
return nil
} }
if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") { if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") {
_, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID) _, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID)
if err != nil { if err != nil {
msg.log.Errorln("Parsing gallery MXID failed:", err) return nil, fmt.Errorf("failed to parse gallery MXID: %w", err)
} }
} }
if ts != 0 { if ts != 0 {
msg.Timestamp = time.Unix(ts, 0) msg.Timestamp = time.Unix(ts, 0)
} }
return msg return msg, nil
} }
func (msg *Message) Insert(txn dbutil.Execable) { func (msg *Message) sqlVariables() []any {
if txn == nil {
txn = msg.db
}
var sender interface{} = msg.Sender
// Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
if msg.Sender.IsEmpty() {
sender = ""
}
mxid := msg.MXID.String() mxid := msg.MXID.String()
if msg.GalleryPart != 0 { if msg.GalleryPart != 0 {
mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid) mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid)
} }
_, err := txn.Exec(` return []any{msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, msg.Sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID}
INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`, msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}
} }
func (msg *Message) MarkSent(ts time.Time) { func (msg *Message) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...)
}
func (msg *Message) MarkSent(ctx context.Context, ts time.Time) error {
msg.Sent = true msg.Sent = true
msg.Timestamp = ts msg.Timestamp = ts
_, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID) return msg.qh.Exec(ctx, markMessageSentQuery, ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}
} }
func (msg *Message) UpdateMXID(txn dbutil.Execable, mxid id.EventID, newType MessageType, newError MessageErrorType) { func (msg *Message) UpdateMXID(ctx context.Context, mxid id.EventID, newType MessageType, newError MessageErrorType) error {
if txn == nil {
txn = msg.db
}
msg.MXID = mxid msg.MXID = mxid
msg.Type = newType msg.Type = newType
msg.Error = newError msg.Error = newError
_, err := txn.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6", return msg.qh.Exec(ctx, updateMessageMXIDQuery, mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}
} }
func (msg *Message) Delete() { func (msg *Message) Delete(ctx context.Context) error {
_, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID) return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err)
}
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,28 +17,56 @@
package database package database
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
) )
func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) { const (
var hash []byte bulkPutPollOptionsQuery = "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
err = rows.Scan(&id, &hash) bulkPutPollOptionsQueryTemplate = "($1, $%d, $%d)"
if err != nil { bulkPutPollOptionsQueryPlaceholder = "($1, $2, $3)"
// return below getPollOptionIDsByHashesQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
} else if len(hash) != 32 { getPollOptionHashesByIDsQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
err = fmt.Errorf("unexpected hash length %d", len(hash)) getPollOptionQuerySQLiteArrayTemplate = " IN (%s)"
} else { getPollOptionQueryArrayPlaceholder = " = ANY($2)"
hashArr = *(*[32]byte)(hash) )
func init() {
if strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, "meow") == bulkPutPollOptionsQuery {
panic("Bulk insert query placeholder not found")
}
if strings.ReplaceAll(getPollOptionIDsByHashesQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
panic("Array select query placeholder not found")
}
if strings.ReplaceAll(getPollOptionHashesByIDsQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
panic("Array select query placeholder not found")
} }
return
} }
func (msg *Message) PutPollOptions(opts map[[32]byte]string) { type pollOption struct {
query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)" id string
hash [32]byte
}
func scanPollOption(rows dbutil.Scannable) (*pollOption, error) {
var hash []byte
var id string
err := rows.Scan(&id, &hash)
if err != nil {
return nil, err
} else if len(hash) != 32 {
return nil, fmt.Errorf("unexpected hash length %d", len(hash))
} else {
return &pollOption{id: id, hash: [32]byte(hash)}, nil
}
}
func (msg *Message) PutPollOptions(ctx context.Context, opts map[[32]byte]string) error {
args := make([]any, len(opts)*2+1) args := make([]any, len(opts)*2+1)
placeholders := make([]string, len(opts)) placeholders := make([]string, len(opts))
args[0] = msg.MXID args[0] = msg.MXID
@ -47,72 +75,47 @@ func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
args[i*2+1] = id args[i*2+1] = id
hashCopy := hash hashCopy := hash
args[i*2+2] = hashCopy[:] args[i*2+2] = hashCopy[:]
placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3) placeholders[i] = fmt.Sprintf(bulkPutPollOptionsQueryTemplate, i*2+2, i*2+3)
i++ i++
} }
query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ",")) query := strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, strings.Join(placeholders, ","))
_, err := msg.db.Exec(query, args...) return msg.qh.Exec(ctx, query, args...)
if err != nil {
msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
}
} }
func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string { func getPollOptions[LookupKey any, Key comparable, Value any](
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)" ctx context.Context,
msg *Message,
query string,
things []LookupKey,
getKeyValue func(option *pollOption) (Key, Value),
) (map[Key]Value, error) {
var args []any var args []any
if msg.db.Dialect == dbutil.Postgres { if msg.qh.GetDB().Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(hashes)} args = []any{msg.MXID, pq.Array(things)}
} else { } else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ","))) query = strings.ReplaceAll(query, getPollOptionQueryArrayPlaceholder, fmt.Sprintf(getPollOptionQuerySQLiteArrayTemplate, strings.TrimSuffix(strings.Repeat("?,", len(things)), ",")))
args = make([]any, len(hashes)+1) args = make([]any, len(things)+1)
args[0] = msg.MXID args[0] = msg.MXID
for i, hash := range hashes { for i, thing := range things {
args[i+1] = hash args[i+1] = thing
} }
} }
ids := make(map[[32]byte]string, len(hashes)) return dbutil.RowIterAsMap(
rows, err := msg.db.Query(query, args...) dbutil.ConvertRowFn[*pollOption](scanPollOption).NewRowIter(msg.qh.GetDB().Query(ctx, query, args...)),
if err != nil { getKeyValue,
msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err) )
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
break
}
ids[hash] = id
}
}
return ids
} }
func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte { func (msg *Message) GetPollOptionIDs(ctx context.Context, hashes [][]byte) (map[[32]byte]string, error) {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)" return getPollOptions(
var args []any ctx, msg, getPollOptionIDsByHashesQuery, hashes,
if msg.db.Dialect == dbutil.Postgres { func(t *pollOption) ([32]byte, string) { return t.hash, t.id },
args = []any{msg.MXID, pq.Array(ids)} )
} else { }
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
args = make([]any, len(ids)+1) func (msg *Message) GetPollOptionHashes(ctx context.Context, ids []string) (map[string][32]byte, error) {
args[0] = msg.MXID return getPollOptions(
for i, id := range ids { ctx, msg, getPollOptionHashesByIDsQuery, ids,
args[i+1] = id func(t *pollOption) (string, [32]byte) { return t.id, t.hash },
} )
}
hashes := make(map[string][32]byte, len(ids))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
break
}
hashes[id] = hash
}
}
return hashes
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,15 +17,14 @@
package database package database
import ( import (
"context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
) )
type PortalKey struct { type PortalKey struct {
@ -53,90 +52,89 @@ func (key PortalKey) String() string {
} }
type PortalQuery struct { type PortalQuery struct {
db *Database *dbutil.QueryHelper[*Portal]
log log.Logger
} }
func (pq *PortalQuery) New() *Portal { func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
return &Portal{ return &Portal{
db: pq.db, qh: qh,
log: pq.log,
} }
} }
const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, relay_user_id, expiration_time" const (
getAllPortalsQuery = `
func (pq *PortalQuery) GetAll() []*Portal { SELECT jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns)) encrypted, last_sync, is_parent, parent_group, in_space,
} first_event_id, next_batch_id, relay_user_id, expiration_time
FROM portal
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal { `
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver) getPortalByJIDQuery = getAllPortalsQuery + " WHERE jid=$1 AND receiver=$2"
} getPortalByMXIDQuery = getAllPortalsQuery + " WHERE mxid=$1"
getPrivateChatsWithQuery = getAllPortalsQuery + " WHERE jid=$1"
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { getPrivateChatsOfQuery = getAllPortalsQuery + " WHERE receiver=$1"
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid) getAllPortalsByParentGroupQuery = getAllPortalsQuery + " WHERE parent_group=$1"
} findPrivateChatPortalsNotInSpaceQuery = `
func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD())
}
func (pq *PortalQuery) GetAllByParentGroup(jid types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE parent_group=$1", portalColumns), jid)
}
func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns), receiver.ToNonAD())
}
func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
receiver = receiver.ToNonAD()
rows, err := pq.db.Query(`
SELECT jid FROM portal SELECT jid FROM portal
LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
WHERE mxid<>'' AND receiver=$1 AND (user_portal.in_space=false OR user_portal.in_space IS NULL) WHERE mxid<>'' AND receiver=$1 AND (user_portal.in_space=false OR user_portal.in_space IS NULL)
`, receiver) `
if err != nil {
pq.log.Errorfln("Failed to find private chats not in space for %s: %v", receiver, err) insertPortalQuery = `
return INSERT INTO portal (
} else if rows == nil { jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
return encrypted, last_sync, is_parent, parent_group, in_space,
} first_event_id, next_batch_id, relay_user_id, expiration_time
for rows.Next() { ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
var key PortalKey `
updatePortalQuery = `
UPDATE portal
SET mxid=$3, name=$4, name_set=$5, topic=$6, topic_set=$7, avatar=$8, avatar_url=$9, avatar_set=$10,
encrypted=$11, last_sync=$12, is_parent=$13, parent_group=$14, in_space=$15,
first_event_id=$16, next_batch_id=$17, relay_user_id=$18, expiration_time=$19
WHERE jid=$1 AND receiver=$2
`
clearPortalInSpaceQuery = "UPDATE portal SET in_space=false WHERE parent_group=$1"
deletePortalQuery = "DELETE FROM portal WHERE jid=$1 AND receiver=$2"
)
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsQuery)
}
func (pq *PortalQuery) GetByJID(ctx context.Context, key PortalKey) (*Portal, error) {
return pq.QueryOne(ctx, getPortalByJIDQuery, key.JID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid)
}
func (pq *PortalQuery) GetAllByJID(ctx context.Context, jid types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getPrivateChatsWithQuery, jid.ToNonAD())
}
func (pq *PortalQuery) FindPrivateChats(ctx context.Context, receiver types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getPrivateChatsOfQuery, receiver.ToNonAD())
}
func (pq *PortalQuery) GetAllByParentGroup(ctx context.Context, jid types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsByParentGroupQuery, jid)
}
func (pq *PortalQuery) FindPrivateChatsNotInSpace(ctx context.Context, receiver types.JID) (keys []PortalKey, err error) {
receiver = receiver.ToNonAD()
scanFn := func(rows dbutil.Scannable) (key PortalKey, err error) {
key.Receiver = receiver key.Receiver = receiver
err = rows.Scan(&key.JID) err = rows.Scan(&key.JID)
if err == nil { return
keys = append(keys, key)
}
} }
return return dbutil.ConvertRowFn[PortalKey](scanFn).
} NewRowIter(pq.GetDB().Query(ctx, findPrivateChatPortalsNotInSpaceQuery, receiver)).
AsList()
func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
portals = append(portals, pq.New().Scan(rows))
}
return
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
} }
type Portal struct { type Portal struct {
db *Database qh *dbutil.QueryHelper[*Portal]
log log.Logger
Key PortalKey Key PortalKey
MXID id.RoomID MXID id.RoomID
@ -161,15 +159,17 @@ type Portal struct {
ExpirationTime uint32 ExpirationTime uint32
} }
func (portal *Portal) Scan(row dbutil.Scannable) *Portal { func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString
var lastSyncTs int64 var lastSyncTs int64
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime) err := row.Scan(
&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet,
&portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted,
&lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace,
&firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime,
)
if err != nil { if err != nil {
if err != sql.ErrNoRows { return nil, err
portal.log.Errorln("Database scan failed:", err)
}
return nil
} }
if lastSyncTs > 0 { if lastSyncTs > 0 {
portal.LastSync = time.Unix(lastSyncTs, 0) portal.LastSync = time.Unix(lastSyncTs, 0)
@ -182,96 +182,36 @@ func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
portal.FirstEventID = id.EventID(firstEventID.String) portal.FirstEventID = id.EventID(firstEventID.String)
portal.NextBatchID = id.BatchID(nextBatchID.String) portal.NextBatchID = id.BatchID(nextBatchID.String)
portal.RelayUserID = id.UserID(relayUserID.String) portal.RelayUserID = id.UserID(relayUserID.String)
return portal return portal, nil
} }
func (portal *Portal) mxidPtr() *id.RoomID { func (portal *Portal) sqlVariables() []any {
if len(portal.MXID) > 0 { var lastSyncTS int64
return &portal.MXID if !portal.LastSync.IsZero() {
lastSyncTS = portal.LastSync.Unix()
} }
return nil return []any{
} portal.Key.JID, portal.Key.Receiver, dbutil.StrPtr(portal.MXID), portal.Name, portal.NameSet,
portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted,
func (portal *Portal) relayUserPtr() *id.UserID { lastSyncTS, portal.IsParent, dbutil.StrPtr(portal.ParentGroup.String()), portal.InSpace,
if len(portal.RelayUserID) > 0 { portal.FirstEventID.String(), portal.NextBatchID.String(), dbutil.StrPtr(portal.RelayUserID), portal.ExpirationTime,
return &portal.RelayUserID
}
return nil
}
func (portal *Portal) parentGroupPtr() *string {
if !portal.ParentGroup.IsEmpty() {
val := portal.ParentGroup.String()
return &val
}
return nil
}
func (portal *Portal) lastSyncTs() int64 {
if portal.LastSync.IsZero() {
return 0
}
return portal.LastSync.Unix()
}
func (portal *Portal) Insert() {
_, err := portal.db.Exec(`
INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id,
relay_user_id, expiration_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
`,
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet,
portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(),
portal.IsParent, portal.parentGroupPtr(), portal.InSpace, portal.FirstEventID.String(), portal.NextBatchID.String(),
portal.relayUserPtr(), portal.ExpirationTime)
if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
} }
} }
func (portal *Portal) Update(txn dbutil.Execable) { func (portal *Portal) Insert(ctx context.Context) error {
if txn == nil { return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...)
txn = portal.db
}
_, err := txn.Exec(`
UPDATE portal
SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13,
first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17
WHERE jid=$18 AND receiver=$19
`, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(),
portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), portal.IsParent, portal.parentGroupPtr(), portal.InSpace,
portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime,
portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
}
} }
func (portal *Portal) Delete() { func (portal *Portal) Update(ctx context.Context) error {
txn, err := portal.db.Begin() return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...)
if err != nil { }
portal.log.Errorfln("Failed to begin transaction to delete portal %v: %v", portal.Key, err)
return func (portal *Portal) Delete(ctx context.Context) error {
} return portal.qh.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
defer func() { err := portal.qh.Exec(ctx, clearPortalInSpaceQuery, portal.Key.JID)
if err != nil { if err != nil {
err = txn.Rollback() return err
if err != nil {
portal.log.Warnfln("Failed to rollback failed portal delete transaction: %v", err)
}
} else if err = txn.Commit(); err != nil {
portal.log.Warnfln("Failed to commit portal delete transaction: %v", err)
} }
}() return portal.qh.Exec(ctx, deletePortalQuery, portal.Key.JID, portal.Key.Receiver)
_, err = txn.Exec("UPDATE portal SET in_space=false WHERE parent_group=$1", portal.Key.JID) })
if err != nil {
portal.log.Warnfln("Failed to mark child groups of %v as not in space: %v", portal.Key.JID, err)
return
}
_, err = txn.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to delete %v: %v", portal.Key, err)
}
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,74 +17,70 @@
package database package database
import ( import (
"context"
"database/sql" "database/sql"
"time" "time"
"go.mau.fi/util/dbutil" "github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
) )
type PuppetQuery struct { type PuppetQuery struct {
db *Database *dbutil.QueryHelper[*Puppet]
log log.Logger
} }
func (pq *PuppetQuery) New() *Puppet { func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
return &Puppet{ return &Puppet{
db: pq.db, qh: qh,
log: pq.log,
EnablePresence: true, EnablePresence: true,
EnableReceipts: true, EnableReceipts: true,
} }
} }
func (pq *PuppetQuery) GetAll() (puppets []*Puppet) { const (
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet") getAllPuppetsQuery = `
if err != nil || rows == nil { SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set,
return nil last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts
} FROM puppet
defer rows.Close() `
for rows.Next() { getPuppetByJIDQuery = getAllPuppetsQuery + " WHERE username=$1"
puppets = append(puppets, pq.New().Scan(rows)) getPuppetByCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid=$1"
} getAllPuppetsWithCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid<>''"
return insertPuppetQuery = `
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`
updatePuppetQuery = `
UPDATE puppet
SET avatar=$2, avatar_url=$3, avatar_set=$4, displayname=$5, name_quality=$6, name_set=$7, contact_info_set=$8,
last_sync=$9, custom_mxid=$10, access_token=$11, next_batch=$12, enable_presence=$13, enable_receipts=$14
WHERE username=$1
`
)
func (pq *PuppetQuery) GetAll(ctx context.Context) ([]*Puppet, error) {
return pq.QueryMany(ctx, getAllPuppetsQuery)
} }
func (pq *PuppetQuery) Get(jid types.JID) *Puppet { func (pq *PuppetQuery) Get(ctx context.Context, jid types.JID) (*Puppet, error) {
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE username=$1", jid.User) return pq.QueryOne(ctx, getPuppetByJIDQuery, jid.User)
if row == nil {
return nil
}
return pq.New().Scan(row)
} }
func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet { func (pq *PuppetQuery) GetByCustomMXID(ctx context.Context, mxid id.UserID) (*Puppet, error) {
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid=$1", mxid) return pq.QueryOne(ctx, getPuppetByCustomMXIDQuery, mxid)
if row == nil {
return nil
}
return pq.New().Scan(row)
} }
func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) { func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, error) {
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid<>''") return pq.QueryMany(ctx, getAllPuppetsWithCustomMXIDQuery)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
return
} }
type Puppet struct { type Puppet struct {
db *Database qh *dbutil.QueryHelper[*Puppet]
log log.Logger
JID types.JID JID types.JID
Avatar string Avatar string
@ -103,17 +99,14 @@ type Puppet struct {
EnableReceipts bool EnableReceipts bool
} }
func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet { func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
var quality, lastSync sql.NullInt64 var quality, lastSync sql.NullInt64
var enablePresence, enableReceipts, nameSet, avatarSet, contactInfoSet sql.NullBool var enablePresence, enableReceipts, nameSet, avatarSet, contactInfoSet sql.NullBool
var username string var username string
err := row.Scan(&username, &avatar, &avatarURL, &displayname, &quality, &nameSet, &avatarSet, &contactInfoSet, &lastSync, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts) err := row.Scan(&username, &avatar, &avatarURL, &displayname, &quality, &nameSet, &avatarSet, &contactInfoSet, &lastSync, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts)
if err != nil { if err != nil {
if err != sql.ErrNoRows { return nil, err
puppet.log.Errorln("Database scan failed:", err)
}
return nil
} }
puppet.JID = types.NewJID(username, types.DefaultUserServer) puppet.JID = types.NewJID(username, types.DefaultUserServer)
puppet.Displayname = displayname.String puppet.Displayname = displayname.String
@ -131,45 +124,30 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
puppet.NextBatch = nextBatch.String puppet.NextBatch = nextBatch.String
puppet.EnablePresence = enablePresence.Bool puppet.EnablePresence = enablePresence.Bool
puppet.EnableReceipts = enableReceipts.Bool puppet.EnableReceipts = enableReceipts.Bool
return puppet return puppet, nil
} }
func (puppet *Puppet) Insert() { func (puppet *Puppet) sqlVariables() []any {
if puppet.JID.Server != types.DefaultUserServer { var lastSyncTS int64
puppet.log.Warnfln("Not inserting %s: not a user", puppet.JID)
return
}
var lastSyncTs int64
if !puppet.LastSync.IsZero() { if !puppet.LastSync.IsZero() {
lastSyncTs = puppet.LastSync.Unix() lastSyncTS = puppet.LastSync.Unix()
} }
_, err := puppet.db.Exec(` return []any{
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set, puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts) puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTS,
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
`, puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
puppet.EnablePresence, puppet.EnableReceipts, puppet.EnablePresence, puppet.EnableReceipts,
)
if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
} }
} }
func (puppet *Puppet) Update() { func (puppet *Puppet) Insert(ctx context.Context) error {
var lastSyncTs int64 if puppet.JID.Server != types.DefaultUserServer {
if !puppet.LastSync.IsZero() { zerolog.Ctx(ctx).Warn().Stringer("jid", puppet.JID).Msg("Not inserting puppet: not a user")
lastSyncTs = puppet.LastSync.Unix() return nil
}
_, err := puppet.db.Exec(`
UPDATE puppet
SET displayname=$1, name_quality=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, contact_info_set=$7, last_sync=$8,
custom_mxid=$9, access_token=$10, next_batch=$11, enable_presence=$12, enable_receipts=$13
WHERE username=$14
`, puppet.Displayname, puppet.NameQuality, puppet.NameSet, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.ContactInfoSet,
lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts,
puppet.JID.User)
if err != nil {
puppet.log.Warnfln("Failed to update %s: %v", puppet.JID, err)
} }
return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...)
}
func (puppet *Puppet) Update(ctx context.Context) error {
return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...)
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,26 +17,20 @@
package database package database
import ( import (
"database/sql" "context"
"errors"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
) )
type ReactionQuery struct { type ReactionQuery struct {
db *Database *dbutil.QueryHelper[*Reaction]
log log.Logger
} }
func (rq *ReactionQuery) New() *Reaction { func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
return &Reaction{ return &Reaction{qh: qh}
db: rq.db,
log: rq.log,
}
} }
const ( const (
@ -55,28 +49,20 @@ const (
DO UPDATE SET mxid=excluded.mxid, jid=excluded.jid DO UPDATE SET mxid=excluded.mxid, jid=excluded.jid
` `
deleteReactionQuery = ` deleteReactionQuery = `
DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4 AND mxid=$5 DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4
` `
) )
func (rq *ReactionQuery) GetByTargetJID(chat PortalKey, jid types.MessageID, sender types.JID) *Reaction { func (rq *ReactionQuery) GetByTargetJID(ctx context.Context, chat PortalKey, jid types.MessageID, sender types.JID) (*Reaction, error) {
return rq.maybeScan(rq.db.QueryRow(getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD())) return rq.QueryOne(ctx, getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD())
} }
func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction { func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
return rq.maybeScan(rq.db.QueryRow(getReactionByMXIDQuery, mxid)) return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
}
func (rq *ReactionQuery) maybeScan(row *sql.Row) *Reaction {
if row == nil {
return nil
}
return rq.New().Scan(row)
} }
type Reaction struct { type Reaction struct {
db *Database qh *dbutil.QueryHelper[*Reaction]
log log.Logger
Chat PortalKey Chat PortalKey
TargetJID types.MessageID TargetJID types.MessageID
@ -85,35 +71,19 @@ type Reaction struct {
JID types.MessageID JID types.MessageID
} }
func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction { func (reaction *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID) return dbutil.ValueOrErr(reaction, row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID))
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
reaction.log.Errorln("Database scan failed:", err)
}
return nil
}
return reaction
} }
func (reaction *Reaction) Upsert(txn dbutil.Execable) { func (reaction *Reaction) sqlVariables() []any {
reaction.Sender = reaction.Sender.ToNonAD() reaction.Sender = reaction.Sender.ToNonAD()
if txn == nil { return []any{reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID}
txn = reaction.db
}
_, err := txn.Exec(upsertReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID)
if err != nil {
reaction.log.Warnfln("Failed to upsert reaction to %s@%s by %s: %v", reaction.Chat, reaction.TargetJID, reaction.Sender, err)
}
} }
func (reaction *Reaction) GetTarget() *Message { func (reaction *Reaction) Upsert(ctx context.Context) error {
return reaction.db.Message.GetByJID(reaction.Chat, reaction.TargetJID) return reaction.qh.Exec(ctx, upsertReactionQuery, reaction.sqlVariables()...)
} }
func (reaction *Reaction) Delete() { func (reaction *Reaction) Delete(ctx context.Context) error {
_, err := reaction.db.Exec(deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID) return reaction.qh.Exec(ctx, deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender)
if err != nil {
reaction.log.Warnfln("Failed to delete reaction %s: %v", reaction.MXID, err)
}
} }

View file

@ -17,6 +17,7 @@
package upgrades package upgrades
import ( import (
"context"
"embed" "embed"
"errors" "errors"
@ -29,7 +30,7 @@ var Table dbutil.UpgradeTable
var rawUpgrades embed.FS var rawUpgrades embed.FS
func init() { func init() {
Table.Register(-1, 35, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error { Table.Register(-1, 35, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version") return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
}) })
Table.RegisterFS(rawUpgrades) Table.RegisterFS(rawUpgrades)

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,63 +17,65 @@
package database package database
import ( import (
"context"
"database/sql" "database/sql"
"sync" "sync"
"time" "time"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
) )
type UserQuery struct { type UserQuery struct {
db *Database *dbutil.QueryHelper[*User]
log log.Logger
} }
func (uq *UserQuery) New() *User { func newUser(qh *dbutil.QueryHelper[*User]) *User {
return &User{ return &User{
db: uq.db, qh: qh,
log: uq.log,
lastReadCache: make(map[PortalKey]time.Time), lastReadCache: make(map[PortalKey]time.Time),
inSpaceCache: make(map[PortalKey]bool), inSpaceCache: make(map[PortalKey]bool),
} }
} }
func (uq *UserQuery) GetAll() (users []*User) { const (
rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`) getAllUsersQuery = `SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`
if err != nil || rows == nil { getUserByMXIDQuery = getAllUsersQuery + ` WHERE mxid=$1`
return nil getUserByUsernameQuery = getAllUsersQuery + ` WHERE username=$1`
} insertUserQuery = `
defer rows.Close() INSERT INTO "user" (
for rows.Next() { mxid, username, agent, device,
users = append(users, uq.New().Scan(rows)) management_room, space_room,
} phone_last_seen, phone_last_pinged, timezone
return ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
updateUserQuery = `
UPDATE "user"
SET username=$1, agent=$2, device=$3,
management_room=$4, space_room=$5,
phone_last_seen=$6, phone_last_pinged=$7, timezone=$8
WHERE mxid=$9
`
getUserLastAppStateKeyIDQuery = "SELECT key_id FROM whatsmeow_app_state_sync_keys WHERE jid=$1 ORDER BY timestamp DESC LIMIT 1"
)
func (uq *UserQuery) GetAll(ctx context.Context) ([]*User, error) {
return uq.QueryMany(ctx, getAllUsersQuery)
} }
func (uq *UserQuery) GetByMXID(userID id.UserID) *User { func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE mxid=$1`, userID) return uq.QueryOne(ctx, getUserByMXIDQuery, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
} }
func (uq *UserQuery) GetByUsername(username string) *User { func (uq *UserQuery) GetByUsername(ctx context.Context, username string) (*User, error) {
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE username=$1`, username) return uq.QueryOne(ctx, getUserByUsernameQuery, username)
if row == nil {
return nil
}
return uq.New().Scan(row)
} }
type User struct { type User struct {
db *Database qh *dbutil.QueryHelper[*User]
log log.Logger
MXID id.UserID MXID id.UserID
JID types.JID JID types.JID
@ -89,20 +91,21 @@ type User struct {
inSpaceCacheLock sync.Mutex inSpaceCacheLock sync.Mutex
} }
func (user *User) Scan(row dbutil.Scannable) *User { func (user *User) Scan(row dbutil.Scannable) (*User, error) {
var username, timezone sql.NullString var username, timezone sql.NullString
var device, agent sql.NullByte var device, agent sql.NullInt16
var phoneLastSeen, phoneLastPinged sql.NullInt64 var phoneLastSeen, phoneLastPinged sql.NullInt64
err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged, &timezone) err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged, &timezone)
if err != nil { if err != nil {
if err != sql.ErrNoRows { return nil, err
user.log.Errorln("Database scan failed:", err)
}
return nil
} }
user.Timezone = timezone.String user.Timezone = timezone.String
if len(username.String) > 0 { if len(username.String) > 0 {
user.JID = types.NewADJID(username.String, agent.Byte, device.Byte) user.JID = types.JID{
User: username.String,
Device: uint16(device.Int16),
Server: types.DefaultUserServer,
}
} }
if phoneLastSeen.Valid { if phoneLastSeen.Valid {
user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0) user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0)
@ -110,66 +113,34 @@ func (user *User) Scan(row dbutil.Scannable) *User {
if phoneLastPinged.Valid { if phoneLastPinged.Valid {
user.PhoneLastPinged = time.Unix(phoneLastPinged.Int64, 0) user.PhoneLastPinged = time.Unix(phoneLastPinged.Int64, 0)
} }
return user return user, nil
} }
func (user *User) usernamePtr() *string { func (user *User) sqlVariables() []any {
var username *string
var agent, device *uint16
if !user.JID.IsEmpty() { if !user.JID.IsEmpty() {
return &user.JID.User username = dbutil.StrPtr(user.JID.User)
var zero uint16
agent = &zero
device = dbutil.NumPtr(user.JID.Device)
} }
return nil return []any{
} username, agent, device, user.ManagementRoom, user.SpaceRoom,
dbutil.UnixPtr(user.PhoneLastSeen), dbutil.UnixPtr(user.PhoneLastPinged),
func (user *User) agentPtr() *uint8 { user.Timezone, user.MXID,
if !user.JID.IsEmpty() {
zero := uint8(0)
return &zero
}
return nil
}
func (user *User) devicePtr() *uint8 {
if !user.JID.IsEmpty() {
device := uint8(user.JID.Device)
return &device
}
return nil
}
func (user *User) phoneLastSeenPtr() *int64 {
if user.PhoneLastSeen.IsZero() {
return nil
}
ts := user.PhoneLastSeen.Unix()
return &ts
}
func (user *User) phoneLastPingedPtr() *int64 {
if user.PhoneLastPinged.IsZero() {
return nil
}
ts := user.PhoneLastPinged.Unix()
return &ts
}
func (user *User) Insert() {
_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone)
if err != nil {
user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
} }
} }
func (user *User) Update() { func (user *User) Insert(ctx context.Context) error {
_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7, timezone=$8 WHERE mxid=$9`, return user.qh.Exec(ctx, insertUserQuery, user.sqlVariables()...)
user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone, user.MXID)
if err != nil {
user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
}
} }
func (user *User) GetLastAppStateKeyID() ([]byte, error) { func (user *User) Update(ctx context.Context) error {
var keyID []byte return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...)
err := user.db.QueryRow("SELECT key_id FROM whatsmeow_app_state_sync_keys ORDER BY timestamp DESC LIMIT 1").Scan(&keyID) }
return keyID, err
func (user *User) GetLastAppStateKeyID(ctx context.Context) (keyID []byte, err error) {
err = user.qh.GetDB().QueryRow(ctx, getUserLastAppStateKeyIDQuery, user.JID).Scan(&keyID)
return
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,69 +17,97 @@
package database package database
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"time" "time"
"github.com/rs/zerolog"
) )
func (user *User) GetLastReadTS(portal PortalKey) time.Time { const (
getLastReadTSQuery = "SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
setLastReadTSQuery = `
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
`
getIsInSpaceQuery = "SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
setIsInSpaceQuery = `
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
`
)
func (user *User) GetLastReadTS(ctx context.Context, portal PortalKey) time.Time {
user.lastReadCacheLock.Lock() user.lastReadCacheLock.Lock()
defer user.lastReadCacheLock.Unlock() defer user.lastReadCacheLock.Unlock()
if cached, ok := user.lastReadCache[portal]; ok { if cached, ok := user.lastReadCache[portal]; ok {
return cached return cached
} }
var ts int64 var ts int64
err := user.db.QueryRow("SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&ts) var parsedTS time.Time
err := user.qh.GetDB().QueryRow(ctx, getLastReadTSQuery, user.MXID, portal.JID, portal.Receiver).Scan(&ts)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
user.log.Warnfln("Failed to scan last read timestamp from user portal table: %v", err) zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to query last read timestamp")
return parsedTS
} }
if ts == 0 { if ts != 0 {
user.lastReadCache[portal] = time.Time{} parsedTS = time.Unix(ts, 0)
} else {
user.lastReadCache[portal] = time.Unix(ts, 0)
} }
user.lastReadCache[portal] = parsedTS
return user.lastReadCache[portal] return user.lastReadCache[portal]
} }
func (user *User) SetLastReadTS(portal PortalKey, ts time.Time) { func (user *User) SetLastReadTS(ctx context.Context, portal PortalKey, ts time.Time) {
user.lastReadCacheLock.Lock() user.lastReadCacheLock.Lock()
defer user.lastReadCacheLock.Unlock() defer user.lastReadCacheLock.Unlock()
_, err := user.db.Exec(` _, err := user.qh.GetDB().Exec(ctx, setLastReadTSQuery, user.MXID, portal.JID, portal.Receiver, ts.Unix())
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
`, user.MXID, portal.JID, portal.Receiver, ts.Unix())
if err != nil { if err != nil {
user.log.Warnfln("Failed to update last read timestamp: %v", err) zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to update last read timestamp")
} else { } else {
user.log.Debugfln("Set last read timestamp of %s in %s to %d", user.MXID, portal.String(), ts.Unix()) zerolog.Ctx(ctx).Debug().
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Time("last_read_ts", ts).
Msg("Updated last read timestamp of portal")
user.lastReadCache[portal] = ts user.lastReadCache[portal] = ts
} }
} }
func (user *User) IsInSpace(portal PortalKey) bool { func (user *User) IsInSpace(ctx context.Context, portal PortalKey) bool {
user.inSpaceCacheLock.Lock() user.inSpaceCacheLock.Lock()
defer user.inSpaceCacheLock.Unlock() defer user.inSpaceCacheLock.Unlock()
if cached, ok := user.inSpaceCache[portal]; ok { if cached, ok := user.inSpaceCache[portal]; ok {
return cached return cached
} }
var inSpace bool var inSpace bool
err := user.db.QueryRow("SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&inSpace) err := user.qh.GetDB().QueryRow(ctx, getIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver).Scan(&inSpace)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
user.log.Warnfln("Failed to scan in space status from user portal table: %v", err) zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to query in space status")
return false
} }
user.inSpaceCache[portal] = inSpace user.inSpaceCache[portal] = inSpace
return inSpace return inSpace
} }
func (user *User) MarkInSpace(portal PortalKey) { func (user *User) MarkInSpace(ctx context.Context, portal PortalKey) {
user.inSpaceCacheLock.Lock() user.inSpaceCacheLock.Lock()
defer user.inSpaceCacheLock.Unlock() defer user.inSpaceCacheLock.Unlock()
_, err := user.db.Exec(` _, err := user.qh.GetDB().Exec(ctx, setIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver)
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
`, user.MXID, portal.JID, portal.Receiver)
if err != nil { if err != nil {
user.log.Warnfln("Failed to update in space status: %v", err) zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to update in space status")
} else { } else {
user.inSpaceCache[portal] = true user.inSpaceCache[portal] = true
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,10 +17,11 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"time" "time"
"go.mau.fi/util/dbutil" "github.com/rs/zerolog"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@ -28,47 +29,74 @@ import (
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
) )
func (portal *Portal) MarkDisappearing(txn dbutil.Execable, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) { func (portal *Portal) MarkDisappearing(ctx context.Context, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) {
if expiresIn == 0 { if expiresIn == 0 {
return return
} }
expiresAt := startsAt.Add(expiresIn) expiresAt := startsAt.Add(expiresIn)
msg := portal.bridge.DB.DisappearingMessage.NewWithValues(portal.MXID, eventID, expiresIn, expiresAt) msg := portal.bridge.DB.DisappearingMessage.NewWithValues(portal.MXID, eventID, expiresIn, expiresAt)
msg.Insert(txn) err := msg.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to insert disappearing message")
}
if expiresAt.Before(time.Now().Add(1 * time.Hour)) { if expiresAt.Before(time.Now().Add(1 * time.Hour)) {
go portal.sleepAndDelete(msg) go portal.sleepAndDelete(context.WithoutCancel(ctx), msg)
} }
} }
func (br *WABridge) SleepAndDeleteUpcoming() { func (br *WABridge) SleepAndDeleteUpcoming(ctx context.Context) {
for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) { msgs, err := br.DB.DisappearingMessage.GetUpcomingScheduled(ctx, 1*time.Hour)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get upcoming disappearing messages")
return
}
for _, msg := range msgs {
portal := br.GetPortalByMXID(msg.RoomID) portal := br.GetPortalByMXID(msg.RoomID)
if portal == nil { if portal == nil {
msg.Delete() err = msg.Delete(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("event_id", msg.EventID).
Msg("Failed to delete disappearing message row with no portal")
}
} else { } else {
go portal.sleepAndDelete(msg) go portal.sleepAndDelete(ctx, msg)
} }
} }
} }
func (portal *Portal) sleepAndDelete(msg *database.DisappearingMessage) { func (portal *Portal) sleepAndDelete(ctx context.Context, msg *database.DisappearingMessage) {
if _, alreadySleeping := portal.currentlySleepingToDelete.LoadOrStore(msg.EventID, true); alreadySleeping { if _, alreadySleeping := portal.currentlySleepingToDelete.LoadOrStore(msg.EventID, true); alreadySleeping {
return return
} }
defer portal.currentlySleepingToDelete.Delete(msg.EventID) defer portal.currentlySleepingToDelete.Delete(msg.EventID)
log := zerolog.Ctx(ctx)
sleepTime := msg.ExpireAt.Sub(time.Now()) sleepTime := msg.ExpireAt.Sub(time.Now())
portal.log.Debugfln("Sleeping for %s to make %s disappear", sleepTime, msg.EventID) log.Debug().
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Dur("sleep_time", sleepTime).
Msg("Sleeping before making message disappear")
time.Sleep(sleepTime) time.Sleep(sleepTime)
_, err := portal.MainIntent().RedactEvent(msg.RoomID, msg.EventID, mautrix.ReqRedact{ _, err := portal.MainIntent().RedactEvent(ctx, msg.RoomID, msg.EventID, mautrix.ReqRedact{
Reason: "Message expired", Reason: "Message expired",
TxnID: fmt.Sprintf("mxwa_disappear_%s", msg.EventID), TxnID: fmt.Sprintf("mxwa_disappear_%s", msg.EventID),
}) })
if err != nil { if err != nil {
portal.log.Warnfln("Failed to make %s disappear: %v", msg.EventID, err) log.Err(err).
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Msg("Failed to make event disappear")
} else { } else {
portal.log.Debugfln("Disappeared %s", msg.EventID) log.Debug().
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Msg("Disappeared event")
}
err = msg.Delete(ctx)
if err != nil {
log.Err(err).Msg("Failed to delete disapperaing message row in database after redacting event")
} }
msg.Delete()
} }

View file

@ -17,12 +17,14 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"html" "html"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
@ -104,22 +106,27 @@ func NewFormatter(bridge *WABridge) *Formatter {
return formatter return formatter
} }
func (formatter *Formatter) getMatrixInfoByJID(roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) { func (formatter *Formatter) getMatrixInfoByJID(ctx context.Context, roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) {
if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil { if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
mxid = puppet.MXID mxid = puppet.MXID
displayname = puppet.Displayname displayname = puppet.Displayname
} }
if user := formatter.bridge.GetUserByJID(jid); user != nil { if user := formatter.bridge.GetUserByJID(jid); user != nil {
mxid = user.MXID mxid = user.MXID
member := formatter.bridge.StateStore.GetMember(roomID, user.MXID) member, err := formatter.bridge.StateStore.GetMember(ctx, roomID, user.MXID)
if len(member.Displayname) > 0 { if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("room_id", roomID).
Stringer("user_id", user.MXID).
Msg("Failed to get member profile from state store")
} else if len(member.Displayname) > 0 {
displayname = member.Displayname displayname = member.Displayname
} }
} }
return return
} }
func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) { func (formatter *Formatter) ParseWhatsApp(ctx context.Context, roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) {
output := html.EscapeString(content.Body) output := html.EscapeString(content.Body)
for regex, replacement := range formatter.waReplString { for regex, replacement := range formatter.waReplString {
output = regex.ReplaceAllString(output, replacement) output = regex.ReplaceAllString(output, replacement)
@ -145,7 +152,7 @@ func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.Messa
// TODO lid support? // TODO lid support?
continue continue
} }
mxid, displayname := formatter.getMatrixInfoByJID(roomID, jid) mxid, displayname := formatter.getMatrixInfoByJID(ctx, roomID, jid)
number := "@" + jid.User number := "@" + jid.User
output = strings.ReplaceAll(output, number, fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname)) output = strings.ReplaceAll(output, number, fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname))
content.Body = strings.ReplaceAll(content.Body, number, displayname) content.Body = strings.ReplaceAll(content.Body, number, displayname)

45
go.mod
View file

@ -1,25 +1,25 @@
module maunium.net/go/mautrix-whatsapp module maunium.net/go/mautrix-whatsapp
go 1.20 go 1.21
require ( require (
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.19 github.com/mattn/go-sqlite3 v1.14.22
github.com/prometheus/client_golang v1.17.0 github.com/prometheus/client_golang v1.19.0
github.com/rs/zerolog v1.31.0 github.com/rs/zerolog v1.32.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/tidwall/gjson v1.17.0 github.com/tidwall/gjson v1.17.1
go.mau.fi/util v0.2.1 go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e
go.mau.fi/webp v0.1.0 go.mau.fi/webp v0.1.0
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735 go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
golang.org/x/image v0.14.0 golang.org/x/image v0.15.0
golang.org/x/net v0.19.0 golang.org/x/net v0.22.0
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.33.0
maunium.net/go/maulogger/v2 v2.4.1 maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa
maunium.net/go/mautrix v0.16.2
) )
require ( require (
@ -27,24 +27,27 @@ require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-isatty v0.0.19 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect github.com/rs/xid v1.5.0 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect github.com/tidwall/sjson v1.2.5 // indirect
github.com/yuin/goldmark v1.6.0 // indirect github.com/yuin/goldmark v1.7.0 // indirect
go.mau.fi/libsignal v0.1.0 // indirect go.mau.fi/libsignal v0.1.0 // indirect
go.mau.fi/zeroconfig v0.1.2 // indirect go.mau.fi/zeroconfig v0.1.2 // indirect
golang.org/x/crypto v0.16.0 // indirect golang.org/x/crypto v0.21.0 // indirect
golang.org/x/sys v0.15.0 // indirect golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
maunium.net/go/mauflag v1.0.0 // indirect maunium.net/go/mauflag v1.0.0 // indirect
) )
//replace go.mau.fi/util => ../../Go/go-util
//replace maunium.net/go/mautrix => ../mautrix-go

100
go.sum
View file

@ -1,6 +1,9 @@
filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek=
filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c h1:WqjRVgUO039eiISCjsZC4F9onOEV93DJAk6v33rsZzY=
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c/go.mod h1:b9FFm9y4mEm36G8ytVmS1vkNzJa0KepmcdVY+qf7qRU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
@ -9,18 +12,18 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@ -30,78 +33,75 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY= github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM= github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY= github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI= github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/libsignal v0.1.0 h1:vAKI/nJ5tMhdzke4cTK1fb0idJzz1JuEIpmjprueC+c= go.mau.fi/libsignal v0.1.0 h1:vAKI/nJ5tMhdzke4cTK1fb0idJzz1JuEIpmjprueC+c=
go.mau.fi/libsignal v0.1.0/go.mod h1:R8ovrTezxtUNzCQE5PH30StOQWWeBskBsWE55vMfY9I= go.mau.fi/libsignal v0.1.0/go.mod h1:R8ovrTezxtUNzCQE5PH30StOQWWeBskBsWE55vMfY9I=
go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw= go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e h1:e1jDj/MjleSS5r9DMRbuCZYKy5Rr+sbsu8eWjtLqrGk=
go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e/go.mod h1:jOAREC/go8T6rGic01cu6WRa90xi9U4z3QmDjRf8xpo=
go.mau.fi/webp v0.1.0 h1:BHObH/DcFntT9KYun5pDr0Ot4eUZO8k2C7eP7vF4ueA= go.mau.fi/webp v0.1.0 h1:BHObH/DcFntT9KYun5pDr0Ot4eUZO8k2C7eP7vF4ueA=
go.mau.fi/webp v0.1.0/go.mod h1:e42Z+VMFrUMS9cpEwGRIor+lQWO8oUAyPyMtcL+NMt8= go.mau.fi/webp v0.1.0/go.mod h1:e42Z+VMFrUMS9cpEwGRIor+lQWO8oUAyPyMtcL+NMt8=
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735 h1:+teJYCOK6M4Kn2TYCj29levhHVwnJTmgCtEXLtgwQtM= go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462 h1:QOGjCIh2WEfkgX/38KLjnNof79GWx0T+KLrhTHiws3s=
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735/go.mod h1:5xTtHNaZpGni6z6aE1iEopjW7wNgsKcolZxZrOujK9M= go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462/go.mod h1:lQHbhaG/fI+6hfGqz5Vzn2OBJBEZ05H0kCP6iJXriN4=
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 h1:qCEDpW1G+vcj3Y7Fy52pEM1AWm3abj8WimGYejI3SC4= golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa h1:TLSWIAWKIWxLghgzWfp7o92pVCcFR3yLsArc0s/tsMs=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa/go.mod h1:0sfLB2ejW+lhgio4UlZMmn5i9SuZ8mxFkonFSamrfTE=
maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg=
maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4=

View file

@ -17,6 +17,7 @@
package main package main
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
@ -24,11 +25,11 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/variationselector"
waProto "go.mau.fi/whatsmeow/binary/proto" waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"go.mau.fi/util/variationselector"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
@ -64,9 +65,8 @@ func (user *User) handleHistorySyncsLoop() {
if batchSend { if batchSend {
// Start the backfill queue. // Start the backfill queue.
user.BackfillQueue = &BackfillQueue{ user.BackfillQueue = &BackfillQueue{
BackfillQuery: user.bridge.DB.Backfill, BackfillQuery: user.bridge.DB.BackfillQueue,
reCheckChannels: []chan bool{}, reCheckChannels: []chan bool{},
log: user.log.Sub("BackfillQueue"),
} }
forwardAndImmediate := []database.BackfillType{database.BackfillImmediate, database.BackfillForward} forwardAndImmediate := []database.BackfillType{database.BackfillImmediate, database.BackfillForward}
@ -109,80 +109,113 @@ func (user *User) handleHistorySyncsLoop() {
const EnqueueBackfillsDelay = 30 * time.Second const EnqueueBackfillsDelay = 30 * time.Second
func (user *User) enqueueAllBackfills() { func (user *User) enqueueAllBackfills() {
nMostRecent := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations) log := user.zlog.With().
if len(nMostRecent) > 0 { Str("method", "User.enqueueAllBackfills").
user.log.Infofln("%v has passed since the last history sync blob, enqueueing backfills for %d chats", EnqueueBackfillsDelay, len(nMostRecent)) Logger()
// Find the portals for all the conversations. ctx := log.WithContext(context.TODO())
portals := []*Portal{} nMostRecent, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations)
for _, conv := range nMostRecent { if err != nil {
jid, err := types.ParseJID(conv.ConversationID) log.Err(err).Msg("Failed to get recent history sync conversations from database")
if err != nil { return
user.log.Warnfln("Failed to parse chat JID '%s' in history sync: %v", conv.ConversationID, err) } else if len(nMostRecent) == 0 {
continue return
}
portals = append(portals, user.GetPortalByJID(jid))
}
user.EnqueueImmediateBackfills(portals)
user.EnqueueForwardBackfills(portals)
user.EnqueueDeferredBackfills(portals)
// Tell the queue to check for new backfill requests.
user.BackfillQueue.ReCheck()
} }
log.Info().
Int("chat_count", len(nMostRecent)).
Msg("Enqueueing backfills for recent chats in history sync")
// Find the portals for all the conversations.
portals := make([]*Portal, 0, len(nMostRecent))
for _, conv := range nMostRecent {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
log.Err(err).Str("conversation_id", conv.ConversationID).Msg("Failed to parse chat JID in history sync")
continue
}
portals = append(portals, user.GetPortalByJID(jid))
}
user.EnqueueImmediateBackfills(ctx, portals)
user.EnqueueForwardBackfills(ctx, portals)
user.EnqueueDeferredBackfills(ctx, portals)
// Tell the queue to check for new backfill requests.
user.BackfillQueue.ReCheck()
} }
func (user *User) backfillAll() { func (user *User) backfillAll() {
conversations := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, -1) log := user.zlog.With().
if len(conversations) > 0 { Str("method", "User.backfillAll").
user.zlog.Info(). Logger()
Int("conversation_count", len(conversations)). ctx := log.WithContext(context.TODO())
Msg("Probably received all history sync blobs, now backfilling conversations") conversations, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, -1)
limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations if err != nil {
bridgedCount := 0 log.Err(err).Msg("Failed to get history sync conversations from database")
// Find the portals for all the conversations. return
for _, conv := range conversations { } else if len(conversations) == 0 {
jid, err := types.ParseJID(conv.ConversationID) return
}
log.Info().
Int("conversation_count", len(conversations)).
Msg("Probably received all history sync blobs, now backfilling conversations")
limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations
bridgedCount := 0
// Find the portals for all the conversations.
for _, conv := range conversations {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
log.Err(err).
Str("conversation_id", conv.ConversationID).
Msg("Failed to parse chat JID in history sync")
continue
}
portal := user.GetPortalByJID(jid)
if portal.MXID != "" {
log.Debug().
Str("portal_jid", portal.Key.JID.String()).
Msg("Chat already has a room, deleting messages from database")
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
if err != nil { if err != nil {
user.zlog.Warn().Err(err). log.Err(err).Str("portal_jid", portal.Key.JID.String()).
Str("conversation_id", conv.ConversationID). Msg("Failed to delete history sync conversation with existing portal from database")
Msg("Failed to parse chat JID in history sync")
continue
} }
portal := user.GetPortalByJID(jid) bridgedCount++
if portal.MXID != "" { } else if hasMessages, err := user.bridge.DB.HistorySync.ConversationHasMessages(ctx, user.MXID, portal.Key); err != nil {
user.zlog.Debug(). log.Err(err).Str("portal_jid", portal.Key.JID.String()).Msg("Failed to check if chat has messages in history sync")
Str("portal_jid", portal.Key.JID.String()). } else if !hasMessages {
Msg("Chat already has a room, deleting messages from database") log.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String()) err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
bridgedCount++ if err != nil {
} else if !user.bridge.DB.HistorySync.ConversationHasMessages(user.MXID, portal.Key) { log.Err(err).Str("portal_jid", portal.Key.JID.String()).
user.zlog.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync") Msg("Failed to delete history sync conversation with no messages from database")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String()) }
} else if limit < 0 || bridgedCount < limit { } else if limit < 0 || bridgedCount < limit {
bridgedCount++ bridgedCount++
err = portal.CreateMatrixRoom(user, nil, nil, true, true) err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, true)
if err != nil { if err != nil {
user.zlog.Err(err).Msg("Failed to create Matrix room for backfill") log.Err(err).Msg("Failed to create Matrix room for backfill")
}
} }
} }
} }
} }
func (portal *Portal) legacyBackfill(user *User) { func (portal *Portal) legacyBackfill(ctx context.Context, user *User) {
defer portal.latestEventBackfillLock.Unlock() defer portal.latestEventBackfillLock.Unlock()
// This should only be called from CreateMatrixRoom which locks latestEventBackfillLock before creating the room. // This should only be called from CreateMatrixRoom which locks latestEventBackfillLock before creating the room.
if portal.latestEventBackfillLock.TryLock() { if portal.latestEventBackfillLock.TryLock() {
panic("legacyBackfill() called without locking latestEventBackfillLock") panic("legacyBackfill() called without locking latestEventBackfillLock")
} }
// TODO use portal.zlog instead of user.zlog log := zerolog.Ctx(ctx).With().Str("action", "legacy backfill").Logger()
log := user.zlog.With(). ctx = log.WithContext(ctx)
Str("portal_jid", portal.Key.JID.String()). conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, portal.Key)
Str("action", "legacy backfill"). if err != nil {
Logger() log.Err(err).Msg("Failed to get history sync conversation data for backfill")
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, portal.Key) return
messages := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount) }
messages, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount)
if err != nil {
log.Err(err).Msg("Failed to get history sync messages for backfill")
return
}
log.Debug().Int("message_count", len(messages)).Msg("Got messages to backfill from database") log.Debug().Int("message_count", len(messages)).Msg("Got messages to backfill from database")
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
msgEvt, err := user.Client.ParseWebMessage(portal.Key.JID, messages[i]) msgEvt, err := user.Client.ParseWebMessage(portal.Key.JID, messages[i])
@ -194,18 +227,26 @@ func (portal *Portal) legacyBackfill(user *User) {
Msg("Dropping historical message due to parse error") Msg("Dropping historical message due to parse error")
continue continue
} }
portal.handleMessage(user, msgEvt, true) ctx := log.With().
Str("message_id", msgEvt.Info.ID).
Stringer("message_sender", msgEvt.Info.Sender).
Logger().
WithContext(ctx)
portal.handleMessage(ctx, user, msgEvt, true)
} }
if conv != nil { if conv != nil {
isUnread := conv.MarkedAsUnread || conv.UnreadCount > 0 isUnread := conv.MarkedAsUnread || conv.UnreadCount > 0
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour)) isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
shouldMarkAsRead := !isUnread || isTooOld shouldMarkAsRead := !isUnread || isTooOld
if shouldMarkAsRead { if shouldMarkAsRead {
user.markSelfReadFull(portal) user.markSelfReadFull(ctx, portal)
} }
} }
log.Debug().Msg("Backfill complete, deleting leftover messages from database") log.Info().Msg("Backfill complete, deleting leftover messages from database")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String()) err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
if err != nil {
log.Err(err).Msg("Failed to delete history sync conversation from database after backfill")
}
} }
func (user *User) dailyMediaRequestLoop() { func (user *User) dailyMediaRequestLoop() {
@ -224,29 +265,49 @@ func (user *User) dailyMediaRequestLoop() {
if requestStartTime.Before(now) { if requestStartTime.Before(now) {
requestStartTime = requestStartTime.AddDate(0, 0, 1) requestStartTime = requestStartTime.AddDate(0, 0, 1)
} }
log := user.zlog.With().
Str("action", "daily media request loop").
Logger()
ctx := log.WithContext(context.Background())
// Wait to start the loop // Wait to start the loop
user.log.Infof("Waiting until %s to do media retry requests", requestStartTime) log.Info().Time("start_loop_at", requestStartTime).Msg("Waiting until start time to do media retry requests")
time.Sleep(time.Until(requestStartTime)) time.Sleep(time.Until(requestStartTime))
for { for {
mediaBackfillRequests := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(user.MXID) mediaBackfillRequests, err := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(ctx, user.MXID)
user.log.Infof("Sending %d media retry requests", len(mediaBackfillRequests)) if err != nil {
log.Err(err).Msg("Failed to get media retry requests")
} else if len(mediaBackfillRequests) > 0 {
log.Info().Int("media_request_count", len(mediaBackfillRequests)).Msg("Sending media retry requests")
// Send all of the media backfill requests for the user at once // Send all the media backfill requests for the user at once
for _, req := range mediaBackfillRequests { for _, req := range mediaBackfillRequests {
portal := user.GetPortalByJID(req.PortalKey.JID) portal := user.GetPortalByJID(req.PortalKey.JID)
_, err := portal.requestMediaRetry(user, req.EventID, req.MediaKey) _, err = portal.requestMediaRetry(ctx, user, req.EventID, req.MediaKey)
if err != nil { if err != nil {
user.log.Warnf("Failed to send media retry request for %s / %s", req.PortalKey.String(), req.EventID) log.Err(err).
req.Status = database.MediaBackfillRequestStatusRequestFailed Stringer("portal_key", req.PortalKey).
req.Error = err.Error() Stringer("event_id", req.EventID).
} else { Msg("Failed to send media retry request")
user.log.Debugfln("Sent media retry request for %s / %s", req.PortalKey.String(), req.EventID) req.Status = database.MediaBackfillRequestStatusRequestFailed
req.Status = database.MediaBackfillRequestStatusRequested req.Error = err.Error()
} else {
log.Debug().
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Sent media retry request")
req.Status = database.MediaBackfillRequestStatusRequested
}
req.MediaKey = nil
err = req.Upsert(ctx)
if err != nil {
log.Err(err).
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Failed to save status of media retry request")
}
} }
req.MediaKey = nil
req.Upsert()
} }
// Wait for 24 hours before making requests again // Wait for 24 hours before making requests again
@ -254,20 +315,29 @@ func (user *User) dailyMediaRequestLoop() {
} }
} }
func (user *User) backfillInChunks(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) { func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTask, conv *database.HistorySyncConversation, portal *Portal) {
portal.backfillLock.Lock() portal.backfillLock.Lock()
defer portal.backfillLock.Unlock() defer portal.backfillLock.Unlock()
log := zerolog.Ctx(ctx)
if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(portal.MXID, user.MXID) { if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(ctx, portal.MXID, user.MXID) {
portal.ensureUserInvited(user) portal.ensureUserInvited(ctx, user)
} }
backfillState := user.bridge.DB.Backfill.GetBackfillState(user.MXID, &portal.Key) backfillState, err := user.bridge.DB.BackfillState.GetBackfillState(ctx, user.MXID, portal.Key)
if backfillState == nil { if backfillState == nil {
backfillState = user.bridge.DB.Backfill.NewBackfillState(user.MXID, &portal.Key) backfillState = user.bridge.DB.BackfillState.NewBackfillState(user.MXID, portal.Key)
} }
backfillState.SetProcessingBatch(true) err = backfillState.SetProcessingBatch(ctx, true)
defer backfillState.SetProcessingBatch(false) if err != nil {
log.Err(err).Msg("Failed to mark batch as being processed")
}
defer func() {
err = backfillState.SetProcessingBatch(ctx, false)
if err != nil {
log.Err(err).Msg("Failed to mark batch as no longer being processed")
}
}()
var timeEnd *time.Time var timeEnd *time.Time
var forward, shouldMarkAsRead bool var forward, shouldMarkAsRead bool
@ -275,17 +345,27 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
if req.BackfillType == database.BackfillForward { if req.BackfillType == database.BackfillForward {
// TODO this overrides the TimeStart set when enqueuing the backfill // TODO this overrides the TimeStart set when enqueuing the backfill
// maybe the enqueue should instead include the prev event ID // maybe the enqueue should instead include the prev event ID
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key) lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get newest message in chat")
return
}
start := lastMessage.Timestamp.Add(1 * time.Second) start := lastMessage.Timestamp.Add(1 * time.Second)
req.TimeStart = &start req.TimeStart = &start
// Sending events at the end of the room (= latest events) // Sending events at the end of the room (= latest events)
forward = true forward = true
} else { } else {
firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key) firstMessage, err := portal.bridge.DB.Message.GetFirstInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get oldest message in chat")
return
}
if firstMessage != nil { if firstMessage != nil {
end := firstMessage.Timestamp.Add(-1 * time.Second) end := firstMessage.Timestamp.Add(-1 * time.Second)
timeEnd = &end timeEnd = &end
user.log.Debugfln("Limiting backfill to end at %v", end) log.Debug().
Time("oldest_message_ts", firstMessage.Timestamp).
Msg("Limiting backfill to messages older than oldest message")
} else { } else {
// Portal is empty -> events are latest // Portal is empty -> events are latest
forward = true forward = true
@ -303,45 +383,48 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour)) isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
shouldMarkAsRead = !isUnread || isTooOld shouldMarkAsRead = !isUnread || isTooOld
} }
allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents) allMsgs, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
sendDisappearedNotice := false sendDisappearedNotice := false
// If expired messages are on, and a notice has not been sent to this chat // If expired messages are on, and a notice has not been sent to this chat
// about it having disappeared messages at the conversation timestamp, send // about it having disappeared messages at the conversation timestamp, send
// a notice indicating so. // a notice indicating so.
if len(allMsgs) == 0 && conv.EphemeralExpiration != nil && *conv.EphemeralExpiration > 0 { if len(allMsgs) == 0 && conv.EphemeralExpiration != nil && *conv.EphemeralExpiration > 0 {
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key) lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get last message in chat to check if disappeared notice should be sent")
}
if lastMessage == nil || conv.LastMessageTimestamp.After(lastMessage.Timestamp) { if lastMessage == nil || conv.LastMessageTimestamp.After(lastMessage.Timestamp) {
sendDisappearedNotice = true sendDisappearedNotice = true
} }
} }
if !sendDisappearedNotice && len(allMsgs) == 0 { if !sendDisappearedNotice && len(allMsgs) == 0 {
user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID) log.Debug().Msg("Not backfilling chat: no bridgeable messages found")
return return
} }
if len(portal.MXID) == 0 { if len(portal.MXID) == 0 {
user.log.Debugln("Creating portal for", portal.Key.JID, "as part of history sync handling") log.Debug().Msg("Creating portal for chat as part of history sync handling")
err := portal.CreateMatrixRoom(user, nil, nil, true, false) err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, false)
if err != nil { if err != nil {
user.log.Errorfln("Failed to create room for %s during backfill: %v", portal.Key.JID, err) log.Err(err).Msg("Failed to create room for chat during backfill")
return return
} }
} }
// Update the backfill status here after the room has been created. // Update the backfill status here after the room has been created.
portal.updateBackfillStatus(backfillState) portal.updateBackfillStatus(ctx, backfillState)
if sendDisappearedNotice { if sendDisappearedNotice {
user.log.Debugfln("Sending notice to %s that there are disappeared messages ending at %v", portal.Key.JID, conv.LastMessageTimestamp) log.Debug().Time("last_message_time", conv.LastMessageTimestamp).
resp, err := portal.sendMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{ Msg("Sending notice that there are disappeared messages in the chat")
resp, err := portal.sendMessage(ctx, portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
MsgType: event.MsgNotice, MsgType: event.MsgNotice,
Body: portal.formatDisappearingMessageNotice(), Body: portal.formatDisappearingMessageNotice(),
}, nil, conv.LastMessageTimestamp.UnixMilli()) }, nil, conv.LastMessageTimestamp.UnixMilli())
if err != nil { if err != nil {
portal.log.Errorln("Error sending disappearing messages notice event") log.Err(err).Msg("Failed to send disappeared messages notice event")
return return
} }
@ -353,12 +436,18 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
msg.SenderMXID = portal.MainIntent().UserID msg.SenderMXID = portal.MainIntent().UserID
msg.Sent = true msg.Sent = true
msg.Type = database.MsgFake msg.Type = database.MsgFake
msg.Insert(nil) err = msg.Insert(ctx)
user.markSelfReadFull(portal) if err != nil {
log.Err(err).Msg("Failed to save fake message entry for disappearing message timer in backfill")
}
user.markSelfReadFull(ctx, portal)
return return
} }
user.log.Infofln("Backfilling %d messages in %s, %d messages at a time (queue ID: %d)", len(allMsgs), portal.Key.JID, req.MaxBatchEvents, req.QueueID) log.Info().
Int("message_count", len(allMsgs)).
Int("max_batch_events", req.MaxBatchEvents).
Msg("Backfilling messages")
toBackfill := allMsgs[0:] toBackfill := allMsgs[0:]
for len(toBackfill) > 0 { for len(toBackfill) > 0 {
var msgs []*waProto.WebMessageInfo var msgs []*waProto.WebMessageInfo
@ -372,14 +461,14 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
if len(msgs) > 0 { if len(msgs) > 0 {
time.Sleep(time.Duration(req.BatchDelay) * time.Second) time.Sleep(time.Duration(req.BatchDelay) * time.Second)
user.log.Debugfln("Backfilling %d messages in %s (queue ID: %d)", len(msgs), portal.Key.JID, req.QueueID) log.Debug().Int("batch_message_count", len(msgs)).Msg("Backfilling message batch")
portal.backfill(user, msgs, forward, shouldMarkAsRead) portal.backfill(ctx, user, msgs, forward, shouldMarkAsRead)
} }
} }
user.log.Debugfln("Finished backfilling %d messages in %s (queue ID: %d)", len(allMsgs), portal.Key.JID, req.QueueID) log.Debug().Int("message_count", len(allMsgs)).Msg("Finished backfilling messages in queue entry")
err := user.bridge.DB.HistorySync.DeleteMessages(user.MXID, conv.ConversationID, allMsgs) err = user.bridge.DB.HistorySync.DeleteMessages(ctx, user.MXID, conv.ConversationID, allMsgs)
if err != nil { if err != nil {
user.log.Warnfln("Failed to delete %d history sync messages after backfilling (queue ID: %d): %v", len(allMsgs), req.QueueID, err) log.Err(err).Msg("Failed to delete history sync messages after backfilling")
} }
if req.TimeStart == nil { if req.TimeStart == nil {
@ -399,8 +488,11 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
// beginning of time. // beginning of time.
backfillState.FirstExpectedTimestamp = 0 backfillState.FirstExpectedTimestamp = 0
} }
backfillState.Upsert() err = backfillState.Upsert(ctx)
portal.updateBackfillStatus(backfillState) if err != nil {
log.Err(err).Msg("Failed to mark backfill state as completed in database")
}
portal.updateBackfillStatus(ctx, backfillState)
} }
} }
@ -408,13 +500,13 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
if evt == nil || evt.SyncType == nil { if evt == nil || evt.SyncType == nil {
return return
} }
log := user.bridge.ZLog.With(). log := user.zlog.With().
Str("method", "User.storeHistorySync"). Str("method", "User.storeHistorySync").
Str("user_id", user.MXID.String()).
Str("sync_type", evt.GetSyncType().String()). Str("sync_type", evt.GetSyncType().String()).
Uint32("chunk_order", evt.GetChunkOrder()). Uint32("chunk_order", evt.GetChunkOrder()).
Uint32("progress", evt.GetProgress()). Uint32("progress", evt.GetProgress()).
Logger() Logger()
ctx := log.WithContext(context.TODO())
if evt.GetGlobalSettings() != nil { if evt.GetGlobalSettings() != nil {
log.Debug().Interface("global_settings", evt.GetGlobalSettings()).Msg("Got global settings in history sync") log.Debug().Interface("global_settings", evt.GetGlobalSettings()).Msg("Got global settings in history sync")
} }
@ -466,7 +558,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
historySyncConversation := user.bridge.DB.HistorySync.NewConversationWithValues( historySyncConversation := user.bridge.DB.HistorySync.NewConversationWithValues(
user.MXID, user.MXID,
conv.GetId(), conv.GetId(),
&portal.Key, portal.Key,
getConversationTimestamp(conv), getConversationTimestamp(conv),
conv.GetMuteEndTime(), conv.GetMuteEndTime(),
conv.GetArchived(), conv.GetArchived(),
@ -476,7 +568,10 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
conv.EphemeralExpiration, conv.EphemeralExpiration,
conv.GetMarkedAsUnread(), conv.GetMarkedAsUnread(),
conv.GetUnreadCount()) conv.GetUnreadCount())
historySyncConversation.Upsert() err := historySyncConversation.Upsert(ctx)
if err != nil {
log.Err(err).Msg("Failed to insert history sync conversation into database")
}
} }
var minTime, maxTime time.Time var minTime, maxTime time.Time
@ -521,7 +616,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
Msg("Failed to save historical message") Msg("Failed to save historical message")
continue continue
} }
err = message.Insert() err = message.Insert(ctx)
if err != nil { if err != nil {
log.Error().Err(err). log.Error().Err(err).
Int("msg_index", i). Int("msg_index", i).
@ -570,15 +665,20 @@ func getConversationTimestamp(conv *waProto.Conversation) uint64 {
return convTs return convTs
} }
func (user *User) EnqueueImmediateBackfills(portals []*Portal) { func (user *User) EnqueueImmediateBackfills(ctx context.Context, portals []*Portal) {
for priority, portal := range portals { for priority, portal := range portals {
maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents
initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, maxMessages, maxMessages, 0) initialBackfill := user.bridge.DB.BackfillQueue.NewWithValues(user.MXID, database.BackfillImmediate, priority, portal.Key, nil, maxMessages, maxMessages, 0)
initialBackfill.Insert() err := initialBackfill.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert immediate backfill into database")
}
} }
} }
func (user *User) EnqueueDeferredBackfills(portals []*Portal) { func (user *User) EnqueueDeferredBackfills(ctx context.Context, portals []*Portal) {
numPortals := len(portals) numPortals := len(portals)
for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Deferred { for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Deferred {
for portalIdx, portal := range portals { for portalIdx, portal := range portals {
@ -587,22 +687,36 @@ func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
startDaysAgo := time.Now().AddDate(0, 0, -backfillStage.StartDaysAgo) startDaysAgo := time.Now().AddDate(0, 0, -backfillStage.StartDaysAgo)
startDate = &startDaysAgo startDate = &startDaysAgo
} }
backfillMessages := user.bridge.DB.Backfill.NewWithValues( backfillMessages := user.bridge.DB.BackfillQueue.NewWithValues(
user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay) user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
backfillMessages.Insert() err := backfillMessages.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert deferred backfill into database")
}
} }
} }
} }
func (user *User) EnqueueForwardBackfills(portals []*Portal) { func (user *User) EnqueueForwardBackfills(ctx context.Context, portals []*Portal) {
for priority, portal := range portals { for priority, portal := range portals {
lastMsg := user.bridge.DB.Message.GetLastInChat(portal.Key) lastMsg, err := user.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if lastMsg == nil { if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to get last message in chat to enqueue forward backfill")
} else if lastMsg == nil {
continue continue
} }
backfill := user.bridge.DB.Backfill.NewWithValues( backfill := user.bridge.DB.BackfillQueue.NewWithValues(
user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, -1, -1, 0) user.MXID, database.BackfillForward, priority, portal.Key, &lastMsg.Timestamp, -1, -1, 0)
backfill.Insert() err = backfill.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert forward backfill into database")
}
} }
} }
@ -619,12 +733,11 @@ func (portal *Portal) deterministicEventID(sender types.JID, messageID types.Mes
} }
var ( var (
PortalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType}
BackfillStatusEvent = event.Type{Type: "com.beeper.backfill_status", Class: event.StateEventType} BackfillStatusEvent = event.Type{Type: "com.beeper.backfill_status", Class: event.StateEventType}
) )
func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) *mautrix.RespBeeperBatchSend { func (portal *Portal) backfill(ctx context.Context, source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) {
log := zerolog.Ctx(ctx)
var req mautrix.ReqBeeperBatchSend var req mautrix.ReqBeeperBatchSend
var infos []*wrappedInfo var infos []*wrappedInfo
@ -633,7 +746,10 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
req.MarkReadBy = source.MXID req.MarkReadBy = source.MXID
} }
portal.log.Infofln("Processing history sync with %d messages (forward: %t)", len(messages), isForward) log.Info().
Bool("forward", isForward).
Int("message_count", len(messages)).
Msg("Processing history sync message batch")
// The messages are ordered newest to oldest, so iterate them in reverse order. // The messages are ordered newest to oldest, so iterate them in reverse order.
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
webMsg := messages[i] webMsg := messages[i]
@ -641,11 +757,16 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
if err != nil { if err != nil {
continue continue
} }
log := log.With().
Str("message_id", msgEvt.Info.ID).
Stringer("message_sender", msgEvt.Info.Sender).
Logger()
ctx := log.WithContext(ctx)
msgType := getMessageType(msgEvt.Message) msgType := getMessageType(msgEvt.Message)
if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" { if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" {
if msgType != "ignore" { if msgType != "ignore" {
portal.log.Debugfln("Skipping message %s with unknown type in backfill", msgEvt.Info.ID) log.Debug().Msg("Skipping message with unknown type in backfill")
} }
continue continue
} }
@ -654,85 +775,83 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
if !existingContact.Found || existingContact.PushName == "" { if !existingContact.Found || existingContact.PushName == "" {
changed, _, err := source.Client.Store.Contacts.PutPushName(msgEvt.Info.Sender, webMsg.GetPushName()) changed, _, err := source.Client.Store.Contacts.PutPushName(msgEvt.Info.Sender, webMsg.GetPushName())
if err != nil { if err != nil {
source.log.Errorfln("Failed to save push name of %s from historical message in device store: %v", msgEvt.Info.Sender, err) log.Err(err).Msg("Failed to save push name from historical message to device store")
} else if changed { } else if changed {
source.log.Debugfln("Got push name %s for %s from historical message", webMsg.GetPushName(), msgEvt.Info.Sender) log.Debug().Str("push_name", webMsg.GetPushName()).Msg("Got push name from historical message")
} }
} }
} }
puppet := portal.getMessagePuppet(source, &msgEvt.Info) puppet := portal.getMessagePuppet(ctx, source, &msgEvt.Info)
if puppet == nil { if puppet == nil {
continue continue
} }
converted := portal.convertMessage(puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true) converted := portal.convertMessage(ctx, puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true)
if converted == nil { if converted == nil {
portal.log.Debugfln("Skipping unsupported message %s in backfill", msgEvt.Info.ID) log.Debug().Msg("Skipping unsupported message in backfill")
continue continue
} }
if converted.ReplyTo != nil { if converted.ReplyTo != nil {
portal.SetReply(msgEvt.Info.ID, converted.Content, converted.ReplyTo, true) portal.SetReply(ctx, converted.Content, converted.ReplyTo, true)
} }
err = portal.appendBatchEvents(source, converted, &msgEvt.Info, webMsg, &req.Events, &infos) err = portal.appendBatchEvents(ctx, source, converted, &msgEvt.Info, webMsg, &req.Events, &infos)
if err != nil { if err != nil {
portal.log.Errorfln("Error handling message %s during backfill: %v", msgEvt.Info.ID, err) log.Err(err).Msg("Failed to handle message in backfill")
} }
} }
portal.log.Infofln("Made %d Matrix events from messages in batch", len(req.Events)) log.Info().Int("event_count", len(req.Events)).Msg("Made Matrix events from messages in batch")
if len(req.Events) == 0 { if len(req.Events) == 0 {
return nil return
} }
resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &req) resp, err := portal.MainIntent().BeeperBatchSend(ctx, portal.MXID, &req)
if err != nil { if err != nil {
portal.log.Errorln("Error batch sending messages:", err) log.Err(err).Msg("Failed to send batch of messages")
return nil return
} else { }
txn, err := portal.bridge.DB.Begin() err = portal.bridge.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
if err != nil { return portal.finishBatch(ctx, resp.EventIDs, infos)
portal.log.Errorln("Failed to start transaction to save batch messages:", err) })
return nil if err != nil {
} log.Err(err).Msg("Failed to save message batch to database")
return
portal.finishBatch(txn, resp.EventIDs, infos) }
log.Info().Msg("Successfully sent backfill batch")
err = txn.Commit() if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
if err != nil { go portal.requestMediaRetries(context.TODO(), source, resp.EventIDs, infos)
portal.log.Errorln("Failed to commit transaction to save batch messages:", err)
return nil
}
if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
go portal.requestMediaRetries(source, resp.EventIDs, infos)
}
return resp
} }
} }
func (portal *Portal) requestMediaRetries(source *User, eventIDs []id.EventID, infos []*wrappedInfo) { func (portal *Portal) requestMediaRetries(ctx context.Context, source *User, eventIDs []id.EventID, infos []*wrappedInfo) {
for i, info := range infos { for i, info := range infos {
if info != nil && info.Error == database.MsgErrMediaNotFound && info.MediaKey != nil { if info != nil && info.Error == database.MsgErrMediaNotFound && info.MediaKey != nil {
switch portal.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod { switch portal.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod {
case config.MediaRequestMethodImmediate: case config.MediaRequestMethodImmediate:
err := source.Client.SendMediaRetryReceipt(info.MessageInfo, info.MediaKey) err := source.Client.SendMediaRetryReceipt(info.MessageInfo, info.MediaKey)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to send post-backfill media retry request for %s: %v", info.ID, err) portal.zlog.Err(err).Str("message_id", info.ID).Msg("Failed to send post-backfill media retry request")
} else { } else {
portal.log.Debugfln("Sent post-backfill media retry request for %s", info.ID) portal.zlog.Debug().Str("message_id", info.ID).Msg("Sent post-backfill media retry request")
} }
case config.MediaRequestMethodLocalTime: case config.MediaRequestMethodLocalTime:
req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, &portal.Key, eventIDs[i], info.MediaKey) req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, portal.Key, eventIDs[i], info.MediaKey)
req.Upsert() err := req.Upsert(ctx)
if err != nil {
portal.zlog.Err(err).
Stringer("event_id", eventIDs[i]).
Msg("Failed to upsert media backfill request")
}
} }
} }
} }
} }
func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error { func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
if portal.bridge.Config.Bridge.CaptionInMessage { if portal.bridge.Config.Bridge.CaptionInMessage {
converted.MergeCaption() converted.MergeCaption()
} }
mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra, "") mainEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
if err != nil { if err != nil {
return err return err
} }
@ -750,7 +869,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
ExpiresIn: converted.ExpiresIn, ExpiresIn: converted.ExpiresIn,
} }
if converted.Caption != nil { if converted.Caption != nil {
captionEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Caption, nil, "caption") captionEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
if err != nil { if err != nil {
return err return err
} }
@ -762,7 +881,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
} }
if converted.MultiEvent != nil { if converted.MultiEvent != nil {
for i, subEvtContent := range converted.MultiEvent { for i, subEvtContent := range converted.MultiEvent {
subEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i)) subEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
if err != nil { if err != nil {
return err return err
} }
@ -771,7 +890,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
} }
} }
for _, reaction := range raw.GetReactions() { for _, reaction := range raw.GetReactions() {
reactionEvent, reactionInfo := portal.wrapBatchReaction(source, reaction, mainEvt.ID, info.Timestamp) reactionEvent, reactionInfo := portal.wrapBatchReaction(ctx, source, reaction, mainEvt.ID, info.Timestamp)
if reactionEvent != nil { if reactionEvent != nil {
*eventsArray = append(*eventsArray, reactionEvent) *eventsArray = append(*eventsArray, reactionEvent)
*infoArray = append(*infoArray, &wrappedInfo{ *infoArray = append(*infoArray, &wrappedInfo{
@ -785,7 +904,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
return nil return nil
} }
func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) { func (portal *Portal) wrapBatchReaction(ctx context.Context, source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) {
var senderJID types.JID var senderJID types.JID
if reaction.GetKey().GetFromMe() { if reaction.GetKey().GetFromMe() {
senderJID = source.JID.ToNonAD() senderJID = source.JID.ToNonAD()
@ -807,7 +926,7 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction
ID: reaction.GetKey().GetId(), ID: reaction.GetKey().GetId(),
Timestamp: mainEventTS, Timestamp: mainEventTS,
} }
puppet := portal.getMessagePuppet(source, reactionInfo) puppet := portal.getMessagePuppet(ctx, source, reactionInfo)
if puppet == nil { if puppet == nil {
return return
} }
@ -834,12 +953,12 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction
return return
} }
func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) { func (portal *Portal) wrapBatchEvent(ctx context.Context, info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
wrappedContent := event.Content{ wrappedContent := event.Content{
Parsed: content, Parsed: content,
Raw: extraContent, Raw: extraContent,
} }
newEventType, err := portal.encrypt(intent, &wrappedContent, eventType) newEventType, err := portal.encrypt(ctx, intent, &wrappedContent, eventType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -853,37 +972,37 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
}, nil }, nil
} }
func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID, infos []*wrappedInfo) { func (portal *Portal) finishBatch(ctx context.Context, eventIDs []id.EventID, infos []*wrappedInfo) error {
for i, info := range infos { for i, info := range infos {
if info == nil { if info == nil {
continue continue
} }
eventID := eventIDs[i] eventID := eventIDs[i]
portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error) portal.markHandled(ctx, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error)
if info.Type == database.MsgReaction { if info.Type == database.MsgReaction {
portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID) portal.upsertReaction(ctx, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
} }
if info.ExpiresIn > 0 { if info.ExpiresIn > 0 {
portal.MarkDisappearing(txn, eventID, info.ExpiresIn, info.ExpirationStart) portal.MarkDisappearing(ctx, eventID, info.ExpiresIn, info.ExpirationStart)
} }
} }
portal.log.Infofln("Successfully sent %d events", len(eventIDs)) return nil
} }
func (portal *Portal) updateBackfillStatus(backfillState *database.BackfillState) { func (portal *Portal) updateBackfillStatus(ctx context.Context, backfillState *database.BackfillState) {
backfillStatus := "backfilling" backfillStatus := "backfilling"
if backfillState.BackfillComplete { if backfillState.BackfillComplete {
backfillStatus = "complete" backfillStatus = "complete"
} }
_, err := portal.bridge.Bot.SendStateEvent(portal.MXID, BackfillStatusEvent, "", map[string]interface{}{ _, err := portal.bridge.Bot.SendStateEvent(ctx, portal.MXID, BackfillStatusEvent, "", map[string]interface{}{
"status": backfillStatus, "status": backfillStatus,
"first_timestamp": backfillState.FirstExpectedTimestamp * 1000, "first_timestamp": backfillState.FirstExpectedTimestamp * 1000,
}) })
if err != nil { if err != nil {
portal.log.Errorln("Error sending backfill status event:", err) zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill status event to room")
} }
} }

62
main.go
View file

@ -17,6 +17,7 @@
package main package main
import ( import (
"context"
_ "embed" _ "embed"
"net/http" "net/http"
"net/url" "net/url"
@ -26,15 +27,18 @@ import (
"sync" "sync"
"time" "time"
"github.com/rs/zerolog"
waLog "go.mau.fi/whatsmeow/util/log"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"go.mau.fi/util/configupgrade"
"go.mau.fi/whatsmeow" "go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto" waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore" "go.mau.fi/whatsmeow/store/sqlstore"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"go.mau.fi/util/configupgrade"
"maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridge/status"
@ -91,7 +95,7 @@ func (br *WABridge) Init() {
br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage) br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage) br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
Analytics.log = br.Log.Sub("Analytics") Analytics.log = br.ZLog.With().Str("component", "analytics").Logger()
Analytics.url = (&url.URL{ Analytics.url = (&url.URL{
Scheme: "https", Scheme: "https",
Host: br.Config.Analytics.Host, Host: br.Config.Analytics.Host,
@ -100,23 +104,20 @@ func (br *WABridge) Init() {
Analytics.key = br.Config.Analytics.Token Analytics.key = br.Config.Analytics.Token
Analytics.userID = br.Config.Analytics.UserID Analytics.userID = br.Config.Analytics.UserID
if Analytics.IsEnabled() { if Analytics.IsEnabled() {
Analytics.log.Infoln("Analytics metrics are enabled") Analytics.log.Info().Str("override_user_id", Analytics.userID).Msg("Analytics metrics are enabled")
if Analytics.userID != "" {
Analytics.log.Infoln("Overriding analytics user_id with %v", Analytics.userID)
}
} }
br.DB = database.New(br.Bridge.DB, br.Log.Sub("Database")) br.DB = database.New(br.Bridge.DB)
br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), &waLogger{br.Log.Sub("Database").Sub("WhatsApp")}) br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), waLog.Zerolog(br.ZLog.With().Str("db_section", "whatsmeow").Logger()))
br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
ss := br.Config.Bridge.Provisioning.SharedSecret ss := br.Config.Bridge.Provisioning.SharedSecret
if len(ss) > 0 && ss != "disable" { if len(ss) > 0 && ss != "disable" {
br.Provisioning = &ProvisioningAPI{bridge: br} br.Provisioning = &ProvisioningAPI{bridge: br, log: br.ZLog.With().Str("component", "provisioning").Logger()}
} }
br.Formatter = NewFormatter(br) br.Formatter = NewFormatter(br)
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB) br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.ZLog.With().Str("component", "metrics").Logger(), br.DB)
br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent
store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion) store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
@ -148,11 +149,10 @@ func (br *WABridge) Init() {
func (br *WABridge) Start() { func (br *WABridge) Start() {
err := br.WAContainer.Upgrade() err := br.WAContainer.Upgrade()
if err != nil { if err != nil {
br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err) br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to upgrade whatsmeow database")
os.Exit(15) os.Exit(15)
} }
if br.Provisioning != nil { if br.Provisioning != nil {
br.Log.Debugln("Initializing provisioning API")
br.Provisioning.Init() br.Provisioning.Init()
} }
go br.CheckWhatsAppUpdate() go br.CheckWhatsAppUpdate()
@ -166,30 +166,40 @@ func (br *WABridge) Start() {
} }
func (br *WABridge) CheckWhatsAppUpdate() { func (br *WABridge) CheckWhatsAppUpdate() {
br.Log.Debugfln("Checking for WhatsApp web update") br.ZLog.Debug().Msg("Checking for WhatsApp web update")
resp, err := whatsmeow.CheckUpdate(http.DefaultClient) resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
if err != nil { if err != nil {
br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err) br.ZLog.Warn().Err(err).Msg("Failed to check for WhatsApp web update")
return return
} }
if store.GetWAVersion() == resp.ParsedVersion { if store.GetWAVersion() == resp.ParsedVersion {
br.Log.Debugfln("Bridge is using latest WhatsApp web protocol") br.ZLog.Debug().Msg("Bridge is using latest WhatsApp web protocol")
} else if store.GetWAVersion().LessThan(resp.ParsedVersion) { } else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
if resp.IsBelowHard || resp.IsBroken { if resp.IsBelowHard || resp.IsBroken {
br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) br.ZLog.Warn().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore")
} else if resp.IsBelowSoft { } else if resp.IsBelowSoft {
br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) br.ZLog.Info().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol")
} else { } else {
br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) br.ZLog.Debug().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol")
} }
} else { } else {
br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol") br.ZLog.Debug().Msg("Bridge is using newer than latest WhatsApp web protocol")
} }
} }
func (br *WABridge) Loop() { func (br *WABridge) Loop() {
ctx := br.ZLog.With().Str("action", "background loop").Logger().WithContext(context.TODO())
for { for {
br.SleepAndDeleteUpcoming() br.SleepAndDeleteUpcoming(ctx)
time.Sleep(1 * time.Hour) time.Sleep(1 * time.Hour)
br.WarnUsersAboutDisconnection() br.WarnUsersAboutDisconnection()
} }
@ -199,14 +209,14 @@ func (br *WABridge) WarnUsersAboutDisconnection() {
br.usersLock.Lock() br.usersLock.Lock()
for _, user := range br.usersByUsername { for _, user := range br.usersByUsername {
if user.IsConnected() && !user.PhoneRecentlySeen(true) { if user.IsConnected() && !user.PhoneRecentlySeen(true) {
go user.sendPhoneOfflineWarning() go user.sendPhoneOfflineWarning(context.TODO())
} }
} }
br.usersLock.Unlock() br.usersLock.Unlock()
} }
func (br *WABridge) StartUsers() { func (br *WABridge) StartUsers() {
br.Log.Debugln("Starting users") br.ZLog.Debug().Msg("Starting users")
foundAnySessions := false foundAnySessions := false
for _, user := range br.GetAllUsers() { for _, user := range br.GetAllUsers() {
if !user.JID.IsEmpty() { if !user.JID.IsEmpty() {
@ -217,13 +227,13 @@ func (br *WABridge) StartUsers() {
if !foundAnySessions { if !foundAnySessions {
br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}.Fill(nil)) br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}.Fill(nil))
} }
br.Log.Debugln("Starting custom puppets") br.ZLog.Debug().Msg("Starting custom puppets")
for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() { for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
go func(puppet *Puppet) { go func(puppet *Puppet) {
puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID) puppet.zlog.Debug().Stringer("custom_mxid", puppet.CustomMXID).Msg("Starting double puppet")
err := puppet.StartCustomMXID(true) err := puppet.StartCustomMXID(true)
if err != nil { if err != nil {
puppet.log.Errorln("Failed to start custom puppet:", err) puppet.zlog.Err(err).Stringer("custom_mxid", puppet.CustomMXID).Msg("Failed to start double puppet")
} }
}(loopuppet) }(loopuppet)
} }
@ -235,7 +245,7 @@ func (br *WABridge) Stop() {
if user.Client == nil { if user.Client == nil {
continue continue
} }
br.Log.Debugln("Disconnecting", user.MXID) user.zlog.Debug().Msg("Disconnecting user")
user.Client.Disconnect() user.Client.Disconnect()
close(user.historySyncs) close(user.historySyncs)
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -17,8 +17,10 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
@ -35,77 +37,89 @@ func (br *WABridge) CreatePrivatePortal(roomID id.RoomID, brInviter bridge.User,
puppet := brGhost.(*Puppet) puppet := brGhost.(*Puppet)
key := database.NewPortalKey(puppet.JID, inviter.JID) key := database.NewPortalKey(puppet.JID, inviter.JID)
portal := br.GetPortalByJID(key) portal := br.GetPortalByJID(key)
log := br.ZLog.With().
Str("action", "create private portal").
Stringer("target_room_id", roomID).
Stringer("inviter_mxid", inviter.MXID).
Stringer("invitee_jid", puppet.JID).
Logger()
ctx := log.WithContext(context.TODO())
if len(portal.MXID) == 0 { if len(portal.MXID) == 0 {
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal) br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
return return
} }
ok := portal.ensureUserInvited(inviter) ok := portal.ensureUserInvited(ctx, inviter)
if !ok { if !ok {
br.Log.Warnfln("Failed to invite %s to existing private chat portal %s with %s. Redirecting portal to new room...", inviter.MXID, portal.MXID, puppet.JID) log.Warn().Msg("Failed to invite user to existing private chat portal. Redirecting portal to new room...")
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal) br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
return return
} }
intent := puppet.DefaultIntent() intent := puppet.DefaultIntent()
errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%[1]s](https://matrix.to/#/%[1]s)", portal.MXID) errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Config.Homeserver.Domain).MatrixToURL())
errorContent := format.RenderMarkdown(errorMessage, true, false) errorContent := format.RenderMarkdown(errorMessage, true, false)
_, _ = intent.SendMessageEvent(roomID, event.EventMessage, errorContent) _, _ = intent.SendMessageEvent(ctx, roomID, event.EventMessage, errorContent)
br.Log.Debugfln("Leaving private chat room %s as %s after accepting invite from %s as we already have chat with the user", roomID, puppet.MXID, inviter.MXID) log.Debug().Msg("Leaving private chat room from invite as we already have chat with the user")
_, _ = intent.LeaveRoom(roomID) _, _ = intent.LeaveRoom(ctx, roomID)
} }
func (br *WABridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) { func (br *WABridge) createPrivatePortalFromInvite(ctx context.Context, roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) {
log := zerolog.Ctx(ctx)
// TODO check if room is already encrypted // TODO check if room is already encrypted
var existingEncryption event.EncryptionEventContent var existingEncryption event.EncryptionEventContent
var encryptionEnabled bool var encryptionEnabled bool
err := portal.MainIntent().StateEvent(roomID, event.StateEncryption, "", &existingEncryption) err := portal.MainIntent().StateEvent(ctx, roomID, event.StateEncryption, "", &existingEncryption)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to check if encryption is enabled in private chat room %s", roomID) log.Err(err).Msg("Failed to check if encryption is enabled")
} else { } else {
encryptionEnabled = existingEncryption.Algorithm == id.AlgorithmMegolmV1 encryptionEnabled = existingEncryption.Algorithm == id.AlgorithmMegolmV1
} }
portal.MXID = roomID portal.MXID = roomID
portal.updateLogger()
portal.Topic = PrivateChatTopic portal.Topic = PrivateChatTopic
portal.Name = puppet.Displayname portal.Name = puppet.Displayname
portal.AvatarURL = puppet.AvatarURL portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar portal.Avatar = puppet.Avatar
portal.log.Infofln("Created private chat portal in %s after invite from %s", roomID, inviter.MXID) log.Info().Msg("Created private chat portal from invite")
intent := puppet.DefaultIntent() intent := puppet.DefaultIntent()
if br.Config.Bridge.Encryption.Default || encryptionEnabled { if br.Config.Bridge.Encryption.Default || encryptionEnabled {
_, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID}) _, err = intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID})
if err != nil { if err != nil {
portal.log.Warnln("Failed to invite bridge bot to enable e2be:", err) log.Err(err).Msg("Failed to invite bridge bot to enable e2be")
} }
err = br.Bot.EnsureJoined(roomID) err = br.Bot.EnsureJoined(ctx, roomID)
if err != nil { if err != nil {
portal.log.Warnln("Failed to join as bridge bot to enable e2be:", err) log.Err(err).Msg("Failed to join as bridge bot to enable e2be")
} }
if !encryptionEnabled { if !encryptionEnabled {
_, err = intent.SendStateEvent(roomID, event.StateEncryption, "", portal.GetEncryptionEventContent()) _, err = intent.SendStateEvent(ctx, roomID, event.StateEncryption, "", portal.GetEncryptionEventContent())
if err != nil { if err != nil {
portal.log.Warnln("Failed to enable e2be:", err) log.Err(err).Msg("Failed to enable e2be")
} }
} }
br.AS.StateStore.SetMembership(roomID, inviter.MXID, event.MembershipJoin) br.AS.StateStore.SetMembership(ctx, roomID, inviter.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, puppet.MXID, event.MembershipJoin) br.AS.StateStore.SetMembership(ctx, roomID, puppet.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, br.Bot.UserID, event.MembershipJoin) br.AS.StateStore.SetMembership(ctx, roomID, br.Bot.UserID, event.MembershipJoin)
portal.Encrypted = true portal.Encrypted = true
} }
_, _ = portal.MainIntent().SetRoomTopic(portal.MXID, portal.Topic) _, _ = portal.MainIntent().SetRoomTopic(ctx, portal.MXID, portal.Topic)
if portal.shouldSetDMRoomMetadata() { if portal.shouldSetDMRoomMetadata() {
_, err = portal.MainIntent().SetRoomName(portal.MXID, portal.Name) _, err = portal.MainIntent().SetRoomName(ctx, portal.MXID, portal.Name)
portal.NameSet = err == nil portal.NameSet = err == nil
_, err = portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL) _, err = portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, portal.AvatarURL)
portal.AvatarSet = err == nil portal.AvatarSet = err == nil
} }
portal.Update(nil) err = portal.Update(ctx)
portal.UpdateBridgeInfo() if err != nil {
_, _ = intent.SendNotice(roomID, "Private chat portal created") log.Err(err).Msg("Failed to save portal to database after creating from invite")
}
portal.UpdateBridgeInfo(ctx)
_, _ = intent.SendNotice(ctx, roomID, "Private chat portal created")
} }
func (br *WABridge) HandlePresence(evt *event.Event) { func (br *WABridge) HandlePresence(ctx context.Context, evt *event.Event) {
user := br.GetUserByMXIDIfExists(evt.Sender) user := br.GetUserByMXIDIfExists(evt.Sender)
if user == nil || !user.IsLoggedIn() { if user == nil || !user.IsLoggedIn() {
return return
@ -119,15 +133,15 @@ func (br *WABridge) HandlePresence(evt *event.Event) {
presence := types.PresenceAvailable presence := types.PresenceAvailable
if evt.Content.AsPresence().Presence != event.PresenceOnline { if evt.Content.AsPresence().Presence != event.PresenceOnline {
presence = types.PresenceUnavailable presence = types.PresenceUnavailable
user.log.Debugln("Marking offline") user.zlog.Debug().Msg("Marking offline")
} else { } else {
user.log.Debugln("Marking online") user.zlog.Debug().Msg("Marking online")
} }
user.lastPresence = presence user.lastPresence = presence
if user.Client.Store.PushName != "" { if user.Client.Store.PushName != "" {
err := user.Client.SendPresence(presence) err := user.Client.SendPresence(presence)
if err != nil { if err != nil {
user.log.Warnln("Failed to set presence:", err) user.zlog.Err(err).Msg("Failed to set presence")
} }
} }
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -23,7 +23,7 @@ import (
"sync" "sync"
"time" "time"
log "maunium.net/go/maulogger/v2" "github.com/rs/zerolog"
"go.mau.fi/whatsmeow" "go.mau.fi/whatsmeow"
@ -123,7 +123,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
} }
} }
func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType string, confirmed bool, editID id.EventID) id.EventID { func (portal *Portal) sendErrorMessage(ctx context.Context, evt *event.Event, err error, confirmed bool, editID id.EventID) id.EventID {
if !portal.bridge.Config.Bridge.MessageErrorNotices { if !portal.bridge.Config.Bridge.MessageErrorNotices {
return "" return ""
} }
@ -131,6 +131,21 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri
if confirmed { if confirmed {
certainty = "was not" certainty = "was not"
} }
var msgType string
switch evt.Type {
case event.EventMessage:
msgType = "message"
case event.EventReaction:
msgType = "reaction"
case event.EventRedaction:
msgType = "redaction"
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default:
msgType = "unknown event"
}
msg := fmt.Sprintf("\u26a0 Your %s %s bridged: %v", msgType, certainty, err) msg := fmt.Sprintf("\u26a0 Your %s %s bridged: %v", msgType, certainty, err)
if errors.Is(err, errMessageTakingLong) { if errors.Is(err, errMessageTakingLong) {
msg = fmt.Sprintf("\u26a0 Bridging your %s is taking longer than usual", msgType) msg = fmt.Sprintf("\u26a0 Bridging your %s is taking longer than usual", msgType)
@ -144,15 +159,15 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri
} else { } else {
content.SetReply(evt) content.SetReply(evt)
} }
resp, err := portal.sendMainIntentMessage(content) resp, err := portal.sendMainIntentMessage(ctx, content)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to send bridging error message:", err) zerolog.Ctx(ctx).Err(err).Msg("Failed to send bridging error message")
return "" return ""
} }
return resp.EventID return resp.EventID
} }
func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) { func (portal *Portal) sendStatusEvent(ctx context.Context, evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) {
if !portal.bridge.Config.Bridge.MessageStatusEvents { if !portal.bridge.Config.Bridge.MessageStatusEvents {
return return
} }
@ -179,75 +194,56 @@ func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, de
content.Reason, content.Status, _, _, content.Message = errorToStatusReason(err) content.Reason, content.Status, _, _, content.Message = errorToStatusReason(err)
content.Error = err.Error() content.Error = err.Error()
} }
_, err = intent.SendMessageEvent(portal.MXID, event.BeeperMessageStatus, &content) _, err = intent.SendMessageEvent(ctx, portal.MXID, event.BeeperMessageStatus, &content)
if err != nil { if err != nil {
portal.log.Warnln("Failed to send message status event:", err) zerolog.Ctx(ctx).Err(err).Msg("Failed to send message status event")
} }
} }
func (portal *Portal) sendDeliveryReceipt(eventID id.EventID) { func (portal *Portal) sendDeliveryReceipt(ctx context.Context, eventID id.EventID) {
if portal.bridge.Config.Bridge.DeliveryReceipts { if portal.bridge.Config.Bridge.DeliveryReceipts {
err := portal.bridge.Bot.SendReceipt(portal.MXID, eventID, event.ReceiptTypeRead, nil) err := portal.bridge.Bot.SendReceipt(ctx, portal.MXID, eventID, event.ReceiptTypeRead, nil)
if err != nil { if err != nil {
portal.log.Debugfln("Failed to send delivery receipt for %s: %v", eventID, err) zerolog.Ctx(ctx).Err(err).Msg("Failed to mark message as read by bot (Matrix-side delivery receipt)")
} }
} }
} }
func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part string, ms *metricSender) { func (portal *Portal) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, ms *metricSender) {
var msgType string
switch evt.Type {
case event.EventMessage:
msgType = "message"
case event.EventReaction:
msgType = "reaction"
case event.EventRedaction:
msgType = "redaction"
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default:
msgType = "unknown event"
}
evtDescription := evt.ID.String()
if evt.Type == event.EventRedaction {
evtDescription += fmt.Sprintf(" of %s", evt.Redacts)
}
origEvtID := evt.ID origEvtID := evt.ID
if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil { if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil {
origEvtID = retryMeta.OriginalEventID origEvtID = retryMeta.OriginalEventID
} }
if err != nil { if err != nil {
level := log.LevelError level := zerolog.ErrorLevel
if part == "Ignoring" { if part == "Ignoring" {
level = log.LevelDebug level = zerolog.DebugLevel
} }
portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err) zerolog.Ctx(ctx).WithLevel(level).Err(err).Msg(part + " Matrix event")
reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err) reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err)
checkpointStatus := status.ReasonToCheckpointStatus(reason, statusCode) checkpointStatus := status.ReasonToCheckpointStatus(reason, statusCode)
portal.bridge.SendMessageCheckpoint(evt, status.MsgStepRemote, err, checkpointStatus, ms.getRetryNum()) portal.bridge.SendMessageCheckpoint(evt, status.MsgStepRemote, err, checkpointStatus, ms.getRetryNum())
if sendNotice { if sendNotice {
ms.setNoticeID(portal.sendErrorMessage(evt, err, msgType, isCertain, ms.getNoticeID())) ms.setNoticeID(portal.sendErrorMessage(ctx, evt, err, isCertain, ms.getNoticeID()))
} }
portal.sendStatusEvent(origEvtID, evt.ID, err, nil) portal.sendStatusEvent(ctx, origEvtID, evt.ID, err, nil)
} else { } else {
portal.log.Debugfln("Handled Matrix %s %s", msgType, evtDescription) zerolog.Ctx(ctx).Debug().Msg("Successfully handled Matrix event")
portal.sendDeliveryReceipt(evt.ID) portal.sendDeliveryReceipt(ctx, evt.ID)
portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, ms.getRetryNum()) portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, ms.getRetryNum())
var deliveredTo *[]id.UserID var deliveredTo *[]id.UserID
if portal.IsPrivateChat() { if portal.IsPrivateChat() {
deliveredTo = &[]id.UserID{} deliveredTo = &[]id.UserID{}
} }
portal.sendStatusEvent(origEvtID, evt.ID, nil, deliveredTo) portal.sendStatusEvent(ctx, origEvtID, evt.ID, nil, deliveredTo)
if prevNotice := ms.popNoticeID(); prevNotice != "" { if prevNotice := ms.popNoticeID(); prevNotice != "" {
_, _ = portal.MainIntent().RedactEvent(portal.MXID, prevNotice, mautrix.ReqRedact{ _, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, prevNotice, mautrix.ReqRedact{
Reason: "error resolved", Reason: "error resolved",
}) })
} }
} }
if ms != nil { if ms != nil {
portal.log.Debugfln("Timings for %s: %s", evt.ID, ms.timings.String()) zerolog.Ctx(ctx).Debug().Object("timings", ms.timings).Msg("Matrix event timings")
} }
} }
@ -264,47 +260,16 @@ type messageTimings struct {
totalSend time.Duration totalSend time.Duration
} }
func niceRound(dur time.Duration) time.Duration { func (mt *messageTimings) MarshalZerologObject(e *zerolog.Event) {
switch { e.Dur("init_receive", mt.initReceive).
case dur < time.Millisecond: Dur("decrypt", mt.decrypt).
return dur Dur("implicit_rr", mt.implicitRR).
case dur < time.Second: Dur("portal_queue", mt.portalQueue).
return dur.Round(100 * time.Microsecond) Dur("total_receive", mt.totalReceive).
default: Dur("preproc", mt.preproc).
return dur.Round(time.Millisecond) Dur("convert", mt.convert).
} Object("whatsmeow", mt.whatsmeow).
} Dur("total_send", mt.totalSend)
func (mt *messageTimings) String() string {
mt.initReceive = niceRound(mt.initReceive)
mt.decrypt = niceRound(mt.decrypt)
mt.portalQueue = niceRound(mt.portalQueue)
mt.totalReceive = niceRound(mt.totalReceive)
mt.implicitRR = niceRound(mt.implicitRR)
mt.preproc = niceRound(mt.preproc)
mt.convert = niceRound(mt.convert)
mt.whatsmeow.Queue = niceRound(mt.whatsmeow.Queue)
mt.whatsmeow.Marshal = niceRound(mt.whatsmeow.Marshal)
mt.whatsmeow.GetParticipants = niceRound(mt.whatsmeow.GetParticipants)
mt.whatsmeow.GetDevices = niceRound(mt.whatsmeow.GetDevices)
mt.whatsmeow.GroupEncrypt = niceRound(mt.whatsmeow.GroupEncrypt)
mt.whatsmeow.PeerEncrypt = niceRound(mt.whatsmeow.PeerEncrypt)
mt.whatsmeow.Send = niceRound(mt.whatsmeow.Send)
mt.whatsmeow.Resp = niceRound(mt.whatsmeow.Resp)
mt.whatsmeow.Retry = niceRound(mt.whatsmeow.Retry)
mt.totalSend = niceRound(mt.totalSend)
whatsmeowTimings := "N/A"
if mt.totalSend > 0 {
format := "queue: %[1]s, marshal: %[2]s, ske: %[3]s, pcp: %[4]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
if mt.whatsmeow.GetParticipants == 0 && mt.whatsmeow.GroupEncrypt == 0 {
format = "queue: %[1]s, marshal: %[2]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
}
if mt.whatsmeow.Retry > 0 {
format += ", retry: %[9]s"
}
whatsmeowTimings = fmt.Sprintf(format, mt.whatsmeow.Queue, mt.whatsmeow.Marshal, mt.whatsmeow.GroupEncrypt, mt.whatsmeow.GetParticipants, mt.whatsmeow.GetDevices, mt.whatsmeow.PeerEncrypt, mt.whatsmeow.Send, mt.whatsmeow.Resp, mt.whatsmeow.Retry)
}
return fmt.Sprintf("BRIDGE: receive: %s, decrypt: %s, queue: %s, total hs->portal: %s, implicit rr: %s -- PORTAL: preprocess: %s, convert: %s, total send: %s -- WHATSMEOW: %s", mt.initReceive, mt.decrypt, mt.implicitRR, mt.portalQueue, mt.totalReceive, mt.preproc, mt.convert, mt.totalSend, whatsmeowTimings)
} }
type metricSender struct { type metricSender struct {
@ -345,13 +310,13 @@ func (ms *metricSender) setNoticeID(evtID id.EventID) {
} }
} }
func (ms *metricSender) sendMessageMetrics(evt *event.Event, err error, part string, completed bool) { func (ms *metricSender) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, completed bool) {
ms.lock.Lock() ms.lock.Lock()
defer ms.lock.Unlock() defer ms.lock.Unlock()
if !completed && ms.completed { if !completed && ms.completed {
return return
} }
ms.portal.sendMessageMetrics(evt, err, part, ms) ms.portal.sendMessageMetrics(ctx, evt, err, part, ms)
ms.retryNum++ ms.retryNum++
ms.completed = completed ms.completed = completed
} }

View file

@ -18,6 +18,7 @@ package main
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
"strconv" "strconv"
@ -27,7 +28,7 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
log "maunium.net/go/maulogger/v2" "github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
@ -40,7 +41,7 @@ import (
type MetricsHandler struct { type MetricsHandler struct {
db *database.Database db *database.Database
server *http.Server server *http.Server
log log.Logger log zerolog.Logger
running bool running bool
ctx context.Context ctx context.Context
@ -70,7 +71,7 @@ type MetricsHandler struct {
loggedInStateLock sync.Mutex loggedInStateLock sync.Mutex
} }
func NewMetricsHandler(address string, log log.Logger, db *database.Database) *MetricsHandler { func NewMetricsHandler(address string, log zerolog.Logger, db *database.Database) *MetricsHandler {
portalCount := promauto.NewGaugeVec(prometheus.GaugeOpts{ portalCount := promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "whatsapp_portals_total", Name: "whatsapp_portals_total",
Help: "Number of portal rooms on Matrix", Help: "Number of portal rooms on Matrix",
@ -232,31 +233,31 @@ func (mh *MetricsHandler) TrackConnectionState(jid types.JID, connected bool) {
func (mh *MetricsHandler) updateStats() { func (mh *MetricsHandler) updateStats() {
start := time.Now() start := time.Now()
var puppetCount int var puppetCount int
err := mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount) err := mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
if err != nil { if err != nil {
mh.log.Warnln("Failed to scan number of puppets:", err) mh.log.Err(err).Msg("Failed to scan number of puppets")
} else { } else {
mh.puppetCount.Set(float64(puppetCount)) mh.puppetCount.Set(float64(puppetCount))
} }
var userCount int var userCount int
err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount) err = mh.db.QueryRow(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
if err != nil { if err != nil {
mh.log.Warnln("Failed to scan number of users:", err) mh.log.Err(err).Msg("Failed to scan number of users")
} else { } else {
mh.userCount.Set(float64(userCount)) mh.userCount.Set(float64(userCount))
} }
var messageCount int var messageCount int
err = mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount) err = mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
if err != nil { if err != nil {
mh.log.Warnln("Failed to scan number of messages:", err) mh.log.Err(err).Msg("Failed to scan number of messages")
} else { } else {
mh.messageCount.Set(float64(messageCount)) mh.messageCount.Set(float64(messageCount))
} }
var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int
err = mh.db.QueryRowContext(mh.ctx, ` err = mh.db.QueryRow(mh.ctx, `
SELECT SELECT
COUNT(CASE WHEN jid LIKE '%@g.us' AND encrypted THEN 1 END) AS encrypted_group_portals, COUNT(CASE WHEN jid LIKE '%@g.us' AND encrypted THEN 1 END) AS encrypted_group_portals,
COUNT(CASE WHEN jid LIKE '%@s.whatsapp.net' AND encrypted THEN 1 END) AS encrypted_private_portals, COUNT(CASE WHEN jid LIKE '%@s.whatsapp.net' AND encrypted THEN 1 END) AS encrypted_private_portals,
@ -265,7 +266,7 @@ func (mh *MetricsHandler) updateStats() {
FROM portal WHERE mxid<>'' FROM portal WHERE mxid<>''
`).Scan(&encryptedGroupCount, &encryptedPrivateCount, &unencryptedGroupCount, &unencryptedPrivateCount) `).Scan(&encryptedGroupCount, &encryptedPrivateCount, &unencryptedGroupCount, &unencryptedPrivateCount)
if err != nil { if err != nil {
mh.log.Warnln("Failed to scan number of portals:", err) mh.log.Err(err).Msg("Failed to scan number of portals")
} else { } else {
mh.encryptedGroupCount.Set(float64(encryptedGroupCount)) mh.encryptedGroupCount.Set(float64(encryptedGroupCount))
mh.encryptedPrivateCount.Set(float64(encryptedPrivateCount)) mh.encryptedPrivateCount.Set(float64(encryptedPrivateCount))
@ -279,7 +280,10 @@ func (mh *MetricsHandler) startUpdatingStats() {
defer func() { defer func() {
err := recover() err := recover()
if err != nil { if err != nil {
mh.log.Fatalfln("Panic in metric updater: %v\n%s", err, string(debug.Stack())) mh.log.WithLevel(zerolog.PanicLevel).
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
Interface(zerolog.ErrorFieldName, err).
Msg("Panic in metric updater")
} }
}() }()
ticker := time.Tick(10 * time.Second) ticker := time.Tick(10 * time.Second)
@ -299,8 +303,8 @@ func (mh *MetricsHandler) Start() {
go mh.startUpdatingStats() go mh.startUpdatingStats()
err := mh.server.ListenAndServe() err := mh.server.ListenAndServe()
mh.running = false mh.running = false
if err != nil && err != http.ErrServerClosed { if err != nil && !errors.Is(err, http.ErrServerClosed) {
mh.log.Fatalln("Error in metrics listener:", err) mh.log.Err(err).Msg("Error in metrics listener")
} }
} }
@ -311,6 +315,6 @@ func (mh *MetricsHandler) Stop() {
mh.stopRecorder() mh.stopRecorder()
err := mh.server.Close() err := mh.server.Close()
if err != nil { if err != nil {
mh.log.Errorln("Error closing metrics listener:", err) mh.log.Err(err).Msg("Failed to close metrics listener")
} }
} }

1991
portal.go

File diff suppressed because it is too large Load diff

View file

@ -17,41 +17,40 @@
package main package main
import ( import (
"bufio"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"strings" "strings"
"time" "time"
"github.com/beeper/libserv/pkg/requestlog"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"go.mau.fi/whatsmeow/appstate" "go.mau.fi/whatsmeow/appstate"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow" "go.mau.fi/whatsmeow"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
type ProvisioningAPI struct { type ProvisioningAPI struct {
bridge *WABridge bridge *WABridge
log log.Logger log zerolog.Logger
} }
func (prov *ProvisioningAPI) Init() { func (prov *ProvisioningAPI) Init() {
prov.log = prov.bridge.Log.Sub("Provisioning") prov.log.Debug().Str("base_path", prov.bridge.Config.Bridge.Provisioning.Prefix).Msg("Enabling provisioning API")
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter() r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
r.Use(hlog.NewHandler(prov.log))
r.Use(requestlog.AccessLogger(true))
r.Use(prov.AuthMiddleware) r.Use(prov.AuthMiddleware)
r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet) r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet) r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
@ -73,7 +72,7 @@ func (prov *ProvisioningAPI) Init() {
prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.bridge_state", prov.BridgeStatePing).Methods(http.MethodPost) prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.bridge_state", prov.BridgeStatePing).Methods(http.MethodPost)
if prov.bridge.Config.Bridge.Provisioning.DebugEndpoints { if prov.bridge.Config.Bridge.Provisioning.DebugEndpoints {
prov.log.Debugln("Enabling debug API at /debug") prov.log.Debug().Msg("Enabling debug API at /debug")
r := prov.bridge.AS.Router.PathPrefix("/debug").Subrouter() r := prov.bridge.AS.Router.PathPrefix("/debug").Subrouter()
r.Use(prov.AuthMiddleware) r.Use(prov.AuthMiddleware)
r.PathPrefix("/pprof").Handler(http.DefaultServeMux) r.PathPrefix("/pprof").Handler(http.DefaultServeMux)
@ -83,26 +82,6 @@ func (prov *ProvisioningAPI) Init() {
r.HandleFunc("/v1/delete_connection", prov.Disconnect).Methods(http.MethodPost) r.HandleFunc("/v1/delete_connection", prov.Disconnect).Methods(http.MethodPost)
} }
type responseWrap struct {
http.ResponseWriter
statusCode int
}
var _ http.Hijacker = (*responseWrap)(nil)
func (rw *responseWrap) WriteHeader(statusCode int) {
rw.ResponseWriter.WriteHeader(statusCode)
rw.statusCode = statusCode
}
func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("response does not implement http.Hijacker")
}
return hijacker.Hijack()
}
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization") auth := r.Header.Get("Authorization")
@ -119,7 +98,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
auth = auth[len("Bearer "):] auth = auth[len("Bearer "):]
} }
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret { if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
prov.log.Infof("Authentication token does not match shared secret") hlog.FromRequest(r).Debug().Msg("Authentication token does not match shared secret")
jsonResponse(w, http.StatusForbidden, map[string]interface{}{ jsonResponse(w, http.StatusForbidden, map[string]interface{}{
"error": "Authentication token does not match shared secret", "error": "Authentication token does not match shared secret",
"errcode": "M_FORBIDDEN", "errcode": "M_FORBIDDEN",
@ -128,11 +107,12 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
} }
userID := r.URL.Query().Get("user_id") userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID)) user := prov.bridge.GetUserByMXID(id.UserID(userID))
start := time.Now() if user != nil {
wWrap := &responseWrap{w, 200} hlog.FromRequest(r).UpdateContext(func(c zerolog.Context) zerolog.Context {
h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user))) return c.Stringer("user_id", user.MXID)
duration := time.Now().Sub(start).Seconds() })
prov.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode) }
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user)))
}) })
} }
@ -157,7 +137,7 @@ func (prov *ProvisioningAPI) DeleteSession(w http.ResponseWriter, r *http.Reques
return return
} }
user.DeleteConnection() user.DeleteConnection()
user.DeleteSession() user.DeleteSession(r.Context())
jsonResponse(w, http.StatusOK, Response{true, "Session information purged"}) jsonResponse(w, http.StatusOK, Response{true, "Session information purged"})
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
} }
@ -245,7 +225,7 @@ func (prov *ProvisioningAPI) ListContacts(w http.ResponseWriter, r *http.Request
ErrCode: "no session", ErrCode: "no session",
}) })
} else if contacts, err := user.Session.Contacts.GetAllContacts(); err != nil { } else if contacts, err := user.Session.Contacts.GetAllContacts(); err != nil {
prov.log.Errorfln("Failed to fetch %s's contacts: %v", user.MXID, err) hlog.FromRequest(r).Err(err).Msg("Failed to fetch all contacts")
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while fetching contact list", Error: "Internal server error while fetching contact list",
ErrCode: "failed to get contacts", ErrCode: "failed to get contacts",
@ -282,7 +262,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
if r.Method == http.MethodPost { if r.Method == http.MethodPost {
err := user.ResyncGroups(r.URL.Query().Get("create_portals") == "true") err := user.ResyncGroups(r.URL.Query().Get("create_portals") == "true")
if err != nil { if err != nil {
prov.log.Errorfln("Failed to resync %s's groups: %v", user.MXID, err) hlog.FromRequest(r).Err(err).Msg("Failed to resync groups")
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while resyncing groups", Error: "Internal server error while resyncing groups",
ErrCode: "failed to sync groups", ErrCode: "failed to sync groups",
@ -291,7 +271,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
} }
} }
if groups, err := user.getCachedGroupList(); err != nil { if groups, err := user.getCachedGroupList(); err != nil {
prov.log.Errorfln("Failed to fetch %s's groups: %v", user.MXID, err) hlog.FromRequest(r).Err(err).Msg("Failed to fetch group list")
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while fetching group list", Error: "Internal server error while fetching group list",
ErrCode: "failed to get groups", ErrCode: "failed to get groups",
@ -368,17 +348,17 @@ func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) {
// resolveIdentifier already responded with an error // resolveIdentifier already responded with an error
return return
} }
portal, puppet, justCreated, err := user.StartPM(jid, "provisioning API PM") portal, puppet, justCreated, err := user.StartPM(r.Context(), jid, "provisioning API PM")
if err != nil { if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err), Error: fmt.Sprintf("Failed to create portal: %v", err),
}) })
} }
status := http.StatusOK statusCode := http.StatusOK
if justCreated { if justCreated {
status = http.StatusCreated statusCode = http.StatusCreated
} }
jsonResponse(w, status, PortalInfo{ jsonResponse(w, statusCode, PortalInfo{
RoomID: portal.MXID, RoomID: portal.MXID,
OtherUser: &OtherUserInfo{ OtherUser: &OtherUserInfo{
JID: puppet.JID, JID: puppet.JID,
@ -449,29 +429,30 @@ func (prov *ProvisioningAPI) OpenGroup(w http.ResponseWriter, r *http.Request) {
ErrCode: "invalid group id", ErrCode: "invalid group id",
}) })
} else if info, err := user.Client.GetGroupInfo(jid); err != nil { } else if info, err := user.Client.GetGroupInfo(jid); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to get group info by JID")
// TODO return better responses for different errors (like ErrGroupNotFound and ErrNotInGroup) // TODO return better responses for different errors (like ErrGroupNotFound and ErrNotInGroup)
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to get group info: %v", err), Error: fmt.Sprintf("Failed to get group info: %v", err),
ErrCode: "error getting group info", ErrCode: "error getting group info",
}) })
} else { } else {
prov.log.Debugln("Importing", jid, "for", user.MXID) hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Importing group chat for user")
portal := user.GetPortalByJID(info.JID) portal := user.GetPortalByJID(info.JID)
status := http.StatusOK statusCode := http.StatusOK
if len(portal.MXID) == 0 { if len(portal.MXID) == 0 {
err = portal.CreateMatrixRoom(user, info, nil, true, true) err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
if err != nil { if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err), Error: fmt.Sprintf("Failed to create portal: %v", err),
}) })
return return
} }
status = http.StatusCreated statusCode = http.StatusCreated
} }
jsonResponse(w, status, PortalInfo{ jsonResponse(w, statusCode, PortalInfo{
RoomID: portal.MXID, RoomID: portal.MXID,
GroupInfo: info, GroupInfo: info,
JustCreated: status == http.StatusCreated, JustCreated: statusCode == http.StatusCreated,
}) })
} }
} }
@ -495,6 +476,7 @@ func (prov *ProvisioningAPI) resolveGroupInvite(w http.ResponseWriter, r *http.R
ErrCode: "invalid invite link", ErrCode: "invalid invite link",
}) })
} else { } else {
hlog.FromRequest(r).Err(err).Msg("Failed to get group info from link")
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to fetch group info with link: %v", err), Error: fmt.Sprintf("Failed to fetch group info with link: %v", err),
ErrCode: "error getting group info", ErrCode: "error getting group info",
@ -530,29 +512,30 @@ func (prov *ProvisioningAPI) JoinGroup(w http.ResponseWriter, r *http.Request) {
}() }()
inviteCode, _ := mux.Vars(r)["inviteCode"] inviteCode, _ := mux.Vars(r)["inviteCode"]
if jid, err := user.Client.JoinGroupWithLink(inviteCode); err != nil { if jid, err := user.Client.JoinGroupWithLink(inviteCode); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to join group")
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to join group: %v", err), Error: fmt.Sprintf("Failed to join group: %v", err),
ErrCode: "error joining group", ErrCode: "error joining group",
}) })
} else { } else {
prov.log.Debugln(user.MXID, "successfully joined group", jid) hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Successfully joined group")
portal := user.GetPortalByJID(jid) portal := user.GetPortalByJID(jid)
status := http.StatusOK statusCode := http.StatusOK
if len(portal.MXID) == 0 { if len(portal.MXID) == 0 {
time.Sleep(500 * time.Millisecond) // Wait for incoming group info to create the portal automatically time.Sleep(500 * time.Millisecond) // Wait for incoming group info to create the portal automatically
err = portal.CreateMatrixRoom(user, info, nil, true, true) err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
if err != nil { if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err), Error: fmt.Sprintf("Failed to create portal: %v", err),
}) })
return return
} }
status = http.StatusCreated statusCode = http.StatusCreated
} }
jsonResponse(w, status, PortalInfo{ jsonResponse(w, statusCode, PortalInfo{
RoomID: portal.MXID, RoomID: portal.MXID,
GroupInfo: info, GroupInfo: info,
JustCreated: status == http.StatusCreated, JustCreated: statusCode == http.StatusCreated,
}) })
} }
} }
@ -616,7 +599,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
} else { } else {
err := user.Client.Logout() err := user.Client.Logout()
if err != nil { if err != nil {
user.log.Warnln("Error while logging out:", err) hlog.FromRequest(r).Err(err).Msg("Unknown error while logging out")
if !force { if !force {
jsonResponse(w, http.StatusInternalServerError, Error{ jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Unknown error while logging out: %v", err), Error: fmt.Sprintf("Unknown error while logging out: %v", err),
@ -632,7 +615,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
user.bridge.Metrics.TrackConnectionState(user.JID, false) user.bridge.Metrics.TrackConnectionState(user.JID, false)
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
user.DeleteSession() user.DeleteSession(r.Context())
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
} }
@ -646,16 +629,17 @@ var upgrader = websocket.Upgrader{
func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id") userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID)) user := prov.bridge.GetUserByMXID(id.UserID(userID))
log := hlog.FromRequest(r)
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
prov.log.Errorln("Failed to upgrade connection to websocket:", err) log.Err(err).Msg("Failed to upgrade connection to websocket")
return return
} }
defer func() { defer func() {
err := c.Close() err := c.Close()
if err != nil { if err != nil {
user.log.Debugln("Error closing websocket:", err) log.Debug().Err(err).Msg("Error closing websocket")
} }
}() }()
@ -670,23 +654,26 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
}() }()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c.SetCloseHandler(func(code int, text string) error { c.SetCloseHandler(func(code int, text string) error {
user.log.Debugfln("Login websocket closed (%d), cancelling login", code) log.Debug().Int("close_code", code).Msg("Login websocket closed, cancelling login")
cancel() cancel()
return nil return nil
}) })
if userTimezone := r.URL.Query().Get("tz"); userTimezone != "" { if userTimezone := r.URL.Query().Get("tz"); userTimezone != "" {
user.log.Debug("Setting timezone to %s", userTimezone) log.Debug().Str("timezone", userTimezone).Msg("Updating user timezone")
user.Timezone = userTimezone user.Timezone = userTimezone
user.Update() err = user.Update(r.Context())
if err != nil {
log.Err(err).Msg("Failed to save user after updating timezone")
}
} else { } else {
user.log.Debug("No timezone provided in request") log.Debug().Msg("No timezone provided in request")
} }
qrChan, err := user.Login(ctx) qrChan, err := user.Login(ctx)
expiryTime := time.Now().Add(160 * time.Second) expiryTime := time.Now().Add(160 * time.Second)
if err != nil { if err != nil {
user.log.Errorln("Failed to log in from provisioning API:", err) log.Err(err).Msg("Failed to log in via provisioning API")
if errors.Is(err, ErrAlreadyLoggedIn) { if errors.Is(err, ErrAlreadyLoggedIn) {
go user.Connect() go user.Connect()
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
@ -704,7 +691,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
if phoneNum != "" { if phoneNum != "" {
pairingCode, err := user.Client.PairPhone(phoneNum, true, whatsmeow.PairClientChrome, "Chrome (Linux)") pairingCode, err := user.Client.PairPhone(phoneNum, true, whatsmeow.PairClientChrome, "Chrome (Linux)")
if err != nil { if err != nil {
user.zlog.Err(err).Msg("Failed to start phone code login") log.Err(err).Msg("Failed to start phone code login")
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
Error: "Failed to request pairing code", Error: "Failed to request pairing code",
ErrCode: "code error", ErrCode: "code error",
@ -712,6 +699,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
go user.DeleteConnection() go user.DeleteConnection()
return return
} else { } else {
log.Debug().Msg("Started phone number login")
_ = c.WriteJSON(map[string]any{ _ = c.WriteJSON(map[string]any{
"pairing_code": pairingCode, "pairing_code": pairingCode,
"timeout": int(time.Until(expiryTime).Seconds()), "timeout": int(time.Until(expiryTime).Seconds()),
@ -719,7 +707,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
} }
} }
user.log.Debugln("Started login via provisioning API") log.Debug().Msg("Started login via provisioning API")
Analytics.Track(user.MXID, "$login_start") Analytics.Track(user.MXID, "$login_start")
for { for {
@ -728,7 +716,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
switch evt.Event { switch evt.Event {
case whatsmeow.QRChannelSuccess.Event: case whatsmeow.QRChannelSuccess.Event:
jid := user.Client.Store.ID jid := user.Client.Store.ID
user.log.Debugln("Successful login as", jid, "via provisioning API") log.Debug().Stringer("jid", jid).Msg("Successful login via provisioning API")
Analytics.Track(user.MXID, "$login_success") Analytics.Track(user.MXID, "$login_success")
_ = c.WriteJSON(map[string]interface{}{ _ = c.WriteJSON(map[string]interface{}{
"success": true, "success": true,
@ -737,7 +725,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
"platform": user.Client.Store.Platform, "platform": user.Client.Store.Platform,
}) })
case whatsmeow.QRChannelTimeout.Event: case whatsmeow.QRChannelTimeout.Event:
user.log.Debugln("Login via provisioning API timed out") log.Debug().Msg("Login via provisioning API timed out")
errCode := "login timed out" errCode := "login timed out"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
@ -745,7 +733,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
ErrCode: errCode, ErrCode: errCode,
}) })
case whatsmeow.QRChannelErrUnexpectedEvent.Event: case whatsmeow.QRChannelErrUnexpectedEvent.Event:
user.log.Debugln("Login via provisioning API failed due to unexpected event") log.Debug().Msg("Login via provisioning API failed due to unexpected event")
errCode := "unexpected event" errCode := "unexpected event"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
@ -753,7 +741,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
ErrCode: errCode, ErrCode: errCode,
}) })
case whatsmeow.QRChannelClientOutdated.Event: case whatsmeow.QRChannelClientOutdated.Event:
user.log.Debugln("Login via provisioning API failed due to outdated client") log.Debug().Msg("Login via provisioning API failed due to outdated client")
errCode := "bridge outdated" errCode := "bridge outdated"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{

137
puppet.go
View file

@ -17,15 +17,15 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"regexp" "regexp"
"sync" "sync"
"time" "time"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge"
@ -59,6 +59,7 @@ func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
} }
func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet { func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
ctx := context.TODO()
jid = jid.ToNonAD() jid = jid.ToNonAD()
if jid.Server == types.LegacyUserServer { if jid.Server == types.LegacyUserServer {
jid.Server = types.DefaultUserServer jid.Server = types.DefaultUserServer
@ -69,11 +70,19 @@ func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
defer br.puppetsLock.Unlock() defer br.puppetsLock.Unlock()
puppet, ok := br.puppets[jid] puppet, ok := br.puppets[jid]
if !ok { if !ok {
dbPuppet := br.DB.Puppet.Get(jid) dbPuppet, err := br.DB.Puppet.Get(ctx, jid)
if err != nil {
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to get puppet from database")
return nil
}
if dbPuppet == nil { if dbPuppet == nil {
dbPuppet = br.DB.Puppet.New() dbPuppet = br.DB.Puppet.New()
dbPuppet.JID = jid dbPuppet.JID = jid
dbPuppet.Insert() err = dbPuppet.Insert(ctx)
if err != nil {
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to insert new puppet to database")
return nil
}
} }
puppet = br.NewPuppet(dbPuppet) puppet = br.NewPuppet(dbPuppet)
br.puppets[puppet.JID] = puppet br.puppets[puppet.JID] = puppet
@ -89,7 +98,10 @@ func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
defer br.puppetsLock.Unlock() defer br.puppetsLock.Unlock()
puppet, ok := br.puppetsByCustomMXID[mxid] puppet, ok := br.puppetsByCustomMXID[mxid]
if !ok { if !ok {
dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid) dbPuppet, err := br.DB.Puppet.GetByCustomMXID(context.TODO(), mxid)
if err != nil {
br.ZLog.Err(err).Stringer("mxid", mxid).Msg("Failed to get puppet by custom mxid from database")
}
if dbPuppet == nil { if dbPuppet == nil {
return nil return nil
} }
@ -137,14 +149,18 @@ func (puppet *Puppet) GetMXID() id.UserID {
} }
func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet { func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID()) return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID(context.TODO()))
} }
func (br *WABridge) GetAllPuppets() []*Puppet { func (br *WABridge) GetAllPuppets() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll()) return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll(context.TODO()))
} }
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet { func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet, err error) []*Puppet {
if err != nil {
br.ZLog.Err(err).Msg("Error getting puppets from database")
return nil
}
br.puppetsLock.Lock() br.puppetsLock.Lock()
defer br.puppetsLock.Unlock() defer br.puppetsLock.Unlock()
output := make([]*Puppet, len(dbPuppets)) output := make([]*Puppet, len(dbPuppets))
@ -175,7 +191,7 @@ func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
return &Puppet{ return &Puppet{
Puppet: dbPuppet, Puppet: dbPuppet,
bridge: br, bridge: br,
log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), zlog: br.ZLog.With().Stringer("puppet_jid", dbPuppet.JID).Logger(),
MXID: br.FormatPuppetMXID(dbPuppet.JID), MXID: br.FormatPuppetMXID(dbPuppet.JID),
} }
@ -185,7 +201,7 @@ type Puppet struct {
*database.Puppet *database.Puppet
bridge *WABridge bridge *WABridge
log log.Logger zlog zerolog.Logger
typingIn id.RoomID typingIn id.RoomID
typingAt time.Time typingAt time.Time
@ -223,47 +239,47 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI {
return puppet.bridge.AS.Intent(puppet.MXID) return puppet.bridge.AS.Intent(puppet.MXID)
} }
func (puppet *Puppet) UpdateAvatar(source *User, forcePortalSync bool) bool { func (puppet *Puppet) UpdateAvatar(ctx context.Context, source *User, forcePortalSync bool) bool {
changed := source.updateAvatar(puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.log, puppet.DefaultIntent()) changed := source.updateAvatar(ctx, puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.DefaultIntent())
if !changed || puppet.Avatar == "unauthorized" { if !changed || puppet.Avatar == "unauthorized" {
if forcePortalSync { if forcePortalSync {
go puppet.updatePortalAvatar() go puppet.updatePortalAvatar(ctx)
} }
return changed return changed
} }
err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL) err := puppet.DefaultIntent().SetAvatarURL(ctx, puppet.AvatarURL)
if err != nil { if err != nil {
puppet.log.Warnln("Failed to set avatar:", err) zerolog.Ctx(ctx).Err(err).Msg("Failed to set avatar from puppet")
} else { } else {
puppet.AvatarSet = true puppet.AvatarSet = true
} }
go puppet.updatePortalAvatar() go puppet.updatePortalAvatar(ctx)
return true return true
} }
func (puppet *Puppet) UpdateName(contact types.ContactInfo, forcePortalSync bool) bool { func (puppet *Puppet) UpdateName(ctx context.Context, contact types.ContactInfo, forcePortalSync bool) bool {
newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(puppet.JID, contact) newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(puppet.JID, contact)
if (puppet.Displayname != newName || !puppet.NameSet) && quality >= puppet.NameQuality { if (puppet.Displayname != newName || !puppet.NameSet) && quality >= puppet.NameQuality {
oldName := puppet.Displayname oldName := puppet.Displayname
puppet.Displayname = newName puppet.Displayname = newName
puppet.NameQuality = quality puppet.NameQuality = quality
puppet.NameSet = false puppet.NameSet = false
err := puppet.DefaultIntent().SetDisplayName(newName) err := puppet.DefaultIntent().SetDisplayName(ctx, newName)
if err == nil { if err == nil {
puppet.log.Debugln("Updated name", oldName, "->", newName) puppet.zlog.Debug().Str("old_name", oldName).Str("new_name", newName).Msg("Updated name")
puppet.NameSet = true puppet.NameSet = true
go puppet.updatePortalName() go puppet.updatePortalName(ctx)
} else { } else {
puppet.log.Warnln("Failed to set display name:", err) puppet.zlog.Err(err).Msg("Failed to set displayname")
} }
return true return true
} else if forcePortalSync { } else if forcePortalSync {
go puppet.updatePortalName() go puppet.updatePortalName(ctx)
} }
return false return false
} }
func (puppet *Puppet) UpdateContactInfo() bool { func (puppet *Puppet) UpdateContactInfo(ctx context.Context) bool {
if !puppet.bridge.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { if !puppet.bridge.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
return false return false
} }
@ -281,9 +297,9 @@ func (puppet *Puppet) UpdateContactInfo() bool {
"com.beeper.bridge.service": "whatsapp", "com.beeper.bridge.service": "whatsapp",
"com.beeper.bridge.network": "whatsapp", "com.beeper.bridge.network": "whatsapp",
} }
err := puppet.DefaultIntent().BeeperUpdateProfile(contactInfo) err := puppet.DefaultIntent().BeeperUpdateProfile(ctx, contactInfo)
if err != nil { if err != nil {
puppet.log.Warnln("Failed to store custom contact info in profile:", err) puppet.zlog.Err(err).Msg("Failed to store custom contact info in profile")
return false return false
} else { } else {
puppet.ContactInfoSet = true puppet.ContactInfoSet = true
@ -300,7 +316,7 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) {
} }
} }
func (puppet *Puppet) updatePortalAvatar() { func (puppet *Puppet) updatePortalAvatar(ctx context.Context) {
puppet.updatePortalMeta(func(portal *Portal) { puppet.updatePortalMeta(func(portal *Portal) {
if portal.Avatar == puppet.Avatar && portal.AvatarURL == puppet.AvatarURL && (portal.AvatarSet || !portal.shouldSetDMRoomMetadata()) { if portal.Avatar == puppet.Avatar && portal.AvatarURL == puppet.AvatarURL && (portal.AvatarSet || !portal.shouldSetDMRoomMetadata()) {
return return
@ -308,28 +324,31 @@ func (puppet *Puppet) updatePortalAvatar() {
portal.AvatarURL = puppet.AvatarURL portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar portal.Avatar = puppet.Avatar
portal.AvatarSet = false portal.AvatarSet = false
defer portal.Update(nil)
if len(portal.MXID) > 0 && !portal.shouldSetDMRoomMetadata() { if len(portal.MXID) > 0 && !portal.shouldSetDMRoomMetadata() {
portal.UpdateBridgeInfo() portal.UpdateBridgeInfo(ctx)
} else if len(portal.MXID) > 0 { } else if len(portal.MXID) > 0 {
_, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL) _, err := portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, puppet.AvatarURL)
if err != nil { if err != nil {
portal.log.Warnln("Failed to set avatar:", err) portal.zlog.Err(err).Msg("Failed to set avatar from puppet")
} else { } else {
portal.AvatarSet = true portal.AvatarSet = true
portal.UpdateBridgeInfo() portal.UpdateBridgeInfo(ctx)
} }
} }
err := portal.Update(ctx)
if err != nil {
portal.zlog.Err(err).Msg("Failed to save portal after updating avatar from puppet")
}
}) })
} }
func (puppet *Puppet) updatePortalName() { func (puppet *Puppet) updatePortalName(ctx context.Context) {
puppet.updatePortalMeta(func(portal *Portal) { puppet.updatePortalMeta(func(portal *Portal) {
portal.UpdateName(puppet.Displayname, types.EmptyJID, true) portal.UpdateName(ctx, puppet.Displayname, types.EmptyJID, true)
}) })
} }
func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName bool, reason string) { func (puppet *Puppet) SyncContact(ctx context.Context, source *User, onlyIfNoName, shouldHavePushName bool, reason string) {
if puppet == nil { if puppet == nil {
return return
} }
@ -337,39 +356,67 @@ func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName
source.EnqueuePuppetResync(puppet) source.EnqueuePuppetResync(puppet)
return return
} }
log := zerolog.Ctx(ctx).With().
Str("method", "Puppet.SyncContact").
Stringer("puppet_jid", puppet.JID).
Stringer("source_user_jid", source.JID).
Stringer("source_user_mxid", source.MXID).
Logger()
ctx = log.WithContext(ctx)
contact, err := source.Client.Store.Contacts.GetContact(puppet.JID) contact, err := source.Client.Store.Contacts.GetContact(puppet.JID)
if err != nil { if err != nil {
puppet.log.Warnfln("Failed to get contact info through %s in SyncContact: %v (sync reason: %s)", source.MXID, reason) log.Err(err).
Stringer("source_mxid", source.MXID).
Str("sync_reason", reason).
Msg("Failed to get contact info through user in SyncContact")
} else if !contact.Found { } else if !contact.Found {
puppet.log.Warnfln("No contact info found through %s in SyncContact (sync reason: %s)", source.MXID, reason) log.Warn().
Stringer("source_mxid", source.MXID).
Str("sync_reason", reason).
Msg("No contact info found through user in SyncContact")
} }
puppet.Sync(source, &contact, false, false) puppet.syncInternal(ctx, source, &contact, false, false)
} }
func (puppet *Puppet) Sync(source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) { func (puppet *Puppet) Sync(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
log := zerolog.Ctx(ctx).With().
Str("method", "Puppet.Sync").
Stringer("puppet_jid", puppet.JID).
Stringer("source_user_jid", source.JID).
Stringer("source_user_mxid", source.MXID).
Logger()
ctx = log.WithContext(ctx)
puppet.syncInternal(ctx, source, contact, forceAvatarSync, forcePortalSync)
}
func (puppet *Puppet) syncInternal(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
log := zerolog.Ctx(ctx)
puppet.syncLock.Lock() puppet.syncLock.Lock()
defer puppet.syncLock.Unlock() defer puppet.syncLock.Unlock()
err := puppet.DefaultIntent().EnsureRegistered() err := puppet.DefaultIntent().EnsureRegistered(ctx)
if err != nil { if err != nil {
puppet.log.Errorln("Failed to ensure registered:", err) log.Err(err).Msg("Failed to ensure registered")
} }
puppet.log.Debugfln("Syncing info through %s", source.JID) log.Debug().Stringer("source_jid", source.JID).Msg("Syncing info through user")
update := false update := false
if contact != nil { if contact != nil {
if puppet.JID.User == source.JID.User { if puppet.JID.User == source.JID.User {
contact.PushName = source.Client.Store.PushName contact.PushName = source.Client.Store.PushName
} }
update = puppet.UpdateName(*contact, forcePortalSync) || update update = puppet.UpdateName(ctx, *contact, forcePortalSync) || update
} }
if len(puppet.Avatar) == 0 || forceAvatarSync || puppet.bridge.Config.Bridge.UserAvatarSync { if len(puppet.Avatar) == 0 || forceAvatarSync || puppet.bridge.Config.Bridge.UserAvatarSync {
update = puppet.UpdateAvatar(source, forcePortalSync) || update update = puppet.UpdateAvatar(ctx, source, forcePortalSync) || update
} }
update = puppet.UpdateContactInfo() || update update = puppet.UpdateContactInfo(ctx) || update
if update || puppet.LastSync.Add(24*time.Hour).Before(time.Now()) { if update || puppet.LastSync.Add(24*time.Hour).Before(time.Now()) {
puppet.LastSync = time.Now() puppet.LastSync = time.Now()
puppet.Update() err = puppet.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save puppet after sync")
}
} }
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan // Copyright (C) 2024 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -19,7 +19,6 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"image" "image"
"net/http" "net/http"
"net/url" "net/url"
@ -27,33 +26,26 @@ import (
"strings" "strings"
"time" "time"
"github.com/tidwall/gjson" "github.com/rs/zerolog"
"golang.org/x/net/idna" "golang.org/x/net/idna"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"go.mau.fi/whatsmeow" "go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto" waProto "go.mau.fi/whatsmeow/binary/proto"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
) )
type BeeperLinkPreview struct { func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*event.BeeperLinkPreview {
mautrix.RespPreviewURL
MatchedURL string `json:"matched_url"`
ImageEncryption *event.EncryptedFileInfo `json:"beeper:image:encryption,omitempty"`
}
func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*BeeperLinkPreview {
if msg.GetMatchedText() == "" { if msg.GetMatchedText() == "" {
return []*BeeperLinkPreview{} return []*event.BeeperLinkPreview{}
} }
output := &BeeperLinkPreview{ output := &event.BeeperLinkPreview{
MatchedURL: msg.GetMatchedText(), MatchedURL: msg.GetMatchedText(),
RespPreviewURL: mautrix.RespPreviewURL{ LinkPreview: event.LinkPreview{
CanonicalURL: msg.GetCanonicalUrl(), CanonicalURL: msg.GetCanonicalUrl(),
Title: msg.GetTitle(), Title: msg.GetTitle(),
Description: msg.GetDescription(), Description: msg.GetDescription(),
@ -68,7 +60,7 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
var err error var err error
thumbnailData, err = source.Client.DownloadThumbnail(msg) thumbnailData, err = source.Client.DownloadThumbnail(msg)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to download thumbnail for link preview: %v", err) zerolog.Ctx(ctx).Err(err).Msg("Failed to download thumbnail for link preview")
} }
} }
if thumbnailData == nil && msg.JpegThumbnail != nil { if thumbnailData == nil && msg.JpegThumbnail != nil {
@ -93,9 +85,9 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
uploadMime = "application/octet-stream" uploadMime = "application/octet-stream"
output.ImageEncryption = &event.EncryptedFileInfo{EncryptedFile: *crypto} output.ImageEncryption = &event.EncryptedFileInfo{EncryptedFile: *crypto}
} }
resp, err := intent.UploadBytes(uploadData, uploadMime) resp, err := intent.UploadBytes(ctx, uploadData, uploadMime)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to reupload thumbnail for link preview: %v", err) zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload thumbnail for link preview")
} else { } else {
if output.ImageEncryption != nil { if output.ImageEncryption != nil {
output.ImageEncryption.URL = resp.ContentURI.CUString() output.ImageEncryption.URL = resp.ContentURI.CUString()
@ -108,36 +100,37 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
output.Type = "video.other" output.Type = "video.other"
} }
return []*BeeperLinkPreview{output} return []*event.BeeperLinkPreview{output}
} }
var URLRegex = regexp.MustCompile(`https?://[^\s/_*]+(?:/\S*)?`) var URLRegex = regexp.MustCompile(`https?://[^\s/_*]+(?:/\S*)?`)
func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, evt *event.Event, dest *waProto.ExtendedTextMessage) bool { func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, content *event.MessageEventContent, dest *waProto.ExtendedTextMessage) bool {
var preview *BeeperLinkPreview log := zerolog.Ctx(ctx)
var preview *event.BeeperLinkPreview
rawPreview := gjson.GetBytes(evt.Content.VeryRaw, `com\.beeper\.linkpreviews`) if content.BeeperLinkPreviews != nil {
if rawPreview.Exists() && rawPreview.IsArray() { // Note: this check explicitly happens after checking for nil: empty arrays are treated as no previews,
var previews []BeeperLinkPreview // but omitting the field means the bridge may look for URLs in the message text.
if err := json.Unmarshal([]byte(rawPreview.Raw), &previews); err != nil || len(previews) == 0 { if len(content.BeeperLinkPreviews) == 0 {
return false return false
} }
// WhatsApp only supports a single preview. // WhatsApp only supports a single preview.
preview = &previews[0] preview = content.BeeperLinkPreviews[0]
} else if portal.bridge.Config.Bridge.URLPreviews { } else if portal.bridge.Config.Bridge.URLPreviews {
if matchedURL := URLRegex.FindString(evt.Content.AsMessage().Body); len(matchedURL) == 0 { if matchedURL := URLRegex.FindString(content.Body); len(matchedURL) == 0 {
return false return false
} else if parsed, err := url.Parse(matchedURL); err != nil { } else if parsed, err := url.Parse(matchedURL); err != nil {
return false return false
} else if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil { } else if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil {
return false return false
} else if mxPreview, err := portal.MainIntent().GetURLPreview(parsed.String()); err != nil { } else if mxPreview, err := portal.MainIntent().GetURLPreview(ctx, parsed.String()); err != nil {
portal.log.Warnfln("Failed to fetch preview for %s: %v", matchedURL, err) log.Err(err).Str("url", matchedURL).Msg("Failed to fetch URL preview")
return false return false
} else { } else {
preview = &BeeperLinkPreview{ preview = &event.BeeperLinkPreview{
RespPreviewURL: *mxPreview, LinkPreview: *mxPreview,
MatchedURL: matchedURL, MatchedURL: matchedURL,
} }
} }
} }
@ -163,22 +156,22 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
imageMXC = preview.ImageEncryption.URL.ParseOrIgnore() imageMXC = preview.ImageEncryption.URL.ParseOrIgnore()
} }
if !imageMXC.IsEmpty() { if !imageMXC.IsEmpty() {
data, err := portal.MainIntent().DownloadBytesContext(ctx, imageMXC) data, err := portal.MainIntent().DownloadBytes(ctx, imageMXC)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to download URL preview image %s in %s: %v", preview.ImageURL, evt.ID, err) log.Err(err).Str("image_url", string(preview.ImageURL)).Msg("Failed to download URL preview image")
return true return true
} }
if preview.ImageEncryption != nil { if preview.ImageEncryption != nil {
err = preview.ImageEncryption.DecryptInPlace(data) err = preview.ImageEncryption.DecryptInPlace(data)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to decrypt URL preview image in %s: %v", evt.ID, err) log.Err(err).Msg("Failed to decrypt URL preview image")
return true return true
} }
} }
dest.MediaKeyTimestamp = proto.Int64(time.Now().Unix()) dest.MediaKeyTimestamp = proto.Int64(time.Now().Unix())
uploadResp, err := sender.Client.Upload(ctx, data, whatsmeow.MediaLinkThumbnail) uploadResp, err := sender.Client.Upload(ctx, data, whatsmeow.MediaLinkThumbnail)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to upload URL preview thumbnail in %s: %v", evt.ID, err) log.Err(err).Msg("Failed to reupload URL preview thumbnail")
return true return true
} }
dest.ThumbnailSha256 = uploadResp.FileSHA256 dest.ThumbnailSha256 = uploadResp.FileSHA256
@ -188,7 +181,7 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
var width, height int var width, height int
dest.JpegThumbnail, width, height, err = createThumbnailAndGetSize(data, false) dest.JpegThumbnail, width, height, err = createThumbnailAndGetSize(data, false)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to create JPEG thumbnail for URL preview in %s: %v", evt.ID, err) log.Err(err).Msg("Failed to create JPEG thumbnail for URL preview")
} }
if preview.ImageHeight > 0 && preview.ImageWidth > 0 { if preview.ImageHeight > 0 && preview.ImageWidth > 0 {
dest.ThumbnailWidth = proto.Uint32(uint32(preview.ImageWidth)) dest.ThumbnailWidth = proto.Uint32(uint32(preview.ImageWidth))

526
user.go

File diff suppressed because it is too large Load diff