Update dependencies and lots of code

* Bump minimum Go version to 1.21
* Add contexts everywhere
* Switch database code to new dbutil patterns
* Finish switching away from maulogger
This commit is contained in:
Tulir Asokan 2024-03-11 22:27:10 +02:00
parent f8a22aab06
commit 103bfc31c6
37 changed files with 3423 additions and 2918 deletions

View File

@ -8,7 +8,8 @@ jobs:
strategy:
fail-fast: false
matrix:
go-version: ["1.20", "1.21"]
go-version: ["1.21", "1.22"]
name: Lint ${{ matrix.go-version == '1.22' && '(latest)' || '(old)' }}
steps:
- uses: actions/checkout@v4

View File

@ -15,6 +15,6 @@ repos:
- id: go-vet-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.2.2
rev: v0.3.1
hooks:
- id: zerolog-ban-msgf

View File

@ -1,3 +1,9 @@
# v0.10.6 (unreleased)
* Bumped minimum Go version to 1.21.
* Added 8-letter code pairing support to provisioning API.
* Added more bugs to fix later.
# v0.10.5 (2023-12-16)
* Added support for sending media to channels.

View File

@ -22,7 +22,7 @@ import (
"fmt"
"net/http"
log "maunium.net/go/maulogger/v2"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/id"
)
@ -30,7 +30,7 @@ type AnalyticsClient struct {
url string
key string
userID string
log log.Logger
log zerolog.Logger
client http.Client
}
@ -88,9 +88,9 @@ func (sc *AnalyticsClient) Track(userID id.UserID, event string, properties ...m
props["bridge"] = "whatsapp"
err := sc.trackSync(userID, event, props)
if err != nil {
sc.log.Errorfln("Error tracking %s: %v", event, err)
sc.log.Err(err).Str("event", event).Msg("Error tracking event")
} else {
sc.log.Debugln("Tracked", event)
sc.log.Debug().Str("event", event).Msg("Tracked event")
}
}()
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan, Sumner Evans
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,22 +17,21 @@
package main
import (
"context"
"time"
log "maunium.net/go/maulogger/v2"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database"
)
type BackfillQueue struct {
BackfillQuery *database.BackfillQuery
BackfillQuery *database.BackfillTaskQuery
reCheckChannels []chan bool
log log.Logger
}
func (bq *BackfillQueue) ReCheck() {
bq.log.Infofln("Sending re-checks to %d channels", len(bq.reCheckChannels))
for _, channel := range bq.reCheckChannels {
go func(c chan bool) {
c <- true
@ -40,12 +39,19 @@ func (bq *BackfillQueue) ReCheck() {
}
}
func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.Backfill {
func (bq *BackfillQueue) GetNextBackfill(ctx context.Context, userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.BackfillTask {
for {
if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(userID, waitForBackfillTypes) {
if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(ctx, userID, waitForBackfillTypes) {
// check for immediate when dealing with deferred
if backfill := bq.BackfillQuery.GetNext(userID, backfillTypes); backfill != nil {
backfill.MarkDispatched()
if backfill, err := bq.BackfillQuery.GetNext(ctx, userID, backfillTypes); err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get next backfill task")
} else if backfill != nil {
err = backfill.MarkDispatched(ctx)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Int("queue_id", backfill.QueueID).
Msg("Failed to mark backfill task as dispatched")
}
return backfill
}
}
@ -58,38 +64,73 @@ func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []datab
}
func (user *User) HandleBackfillRequestsLoop(backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType) {
log := user.zlog.With().
Str("action", "backfill request loop").
Any("types", backfillTypes).
Logger()
ctx := log.WithContext(context.TODO())
reCheckChannel := make(chan bool)
user.BackfillQueue.reCheckChannels = append(user.BackfillQueue.reCheckChannels, reCheckChannel)
for {
req := user.BackfillQueue.GetNextBackfill(user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel)
user.log.Infofln("Handling backfill request %s", req)
req := user.BackfillQueue.GetNextBackfill(ctx, user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel)
log.Info().Any("backfill_request", req).Msg("Handling backfill request")
log := log.With().
Int("queue_id", req.QueueID).
Stringer("portal_jid", req.Portal.JID).
Logger()
ctx := log.WithContext(ctx)
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, *req.Portal)
if conv == nil {
user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String())
req.MarkDone()
conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, req.Portal)
if err != nil {
log.Err(err).Msg("Failed to get conversation data for backfill request")
continue
} else if conv == nil {
log.Debug().Msg("Couldn't find conversation data for backfill request")
err = req.MarkDone(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill request as done after data was not found")
}
continue
}
portal := user.GetPortalByJID(conv.PortalKey.JID)
// Update the client store with basic chat settings.
if conv.MuteEndTime.After(time.Now()) {
user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
err = user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
if err != nil {
log.Err(err).Msg("Failed to save muted until time from conversation data")
}
}
if conv.Archived {
user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
err = user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
if err != nil {
log.Err(err).Msg("Failed to save archived state from conversation data")
}
}
if conv.Pinned > 0 {
user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
err = user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
if err != nil {
log.Err(err).Msg("Failed to save pinned state from conversation data")
}
}
if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
log.Debug().
Uint32("old_time", portal.ExpirationTime).
Uint32("new_time", *conv.EphemeralExpiration).
Msg("Updating portal ephemeral expiration time")
portal.ExpirationTime = *conv.EphemeralExpiration
portal.Update(nil)
err = portal.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal after updating expiration time")
}
}
user.backfillInChunks(req, conv, portal)
req.MarkDone()
user.backfillInChunks(ctx, req, conv, portal)
err = req.MarkDone(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill request as done after backfilling")
}
}
}

View File

@ -93,7 +93,7 @@ func (prov *ProvisioningAPI) BridgeStatePing(w http.ResponseWriter, r *http.Requ
remote = remote.Fill(user)
resp.RemoteStates[remote.RemoteID] = remote
}
user.log.Debugfln("Responding bridge state in bridge status endpoint: %+v", resp)
user.zlog.Debug().Any("response_data", &resp).Msg("Responding bridge state in bridge status endpoint")
jsonResponse(w, http.StatusOK, &resp)
if len(resp.RemoteStates) > 0 {
user.BridgeState.SetPrev(remote)

View File

@ -29,6 +29,7 @@ import (
"strings"
"time"
"github.com/rs/zerolog"
"github.com/skip2/go-qrcode"
"github.com/tidwall/gjson"
@ -37,7 +38,6 @@ import (
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/event"
@ -117,7 +117,10 @@ func fnSetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ce.User.MXID
ce.Portal.Update(nil)
err := ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after setting relay user")
}
ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account")
}
}
@ -139,7 +142,10 @@ func fnUnsetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ""
ce.Portal.Update(nil)
err := ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after clearing relay user")
}
ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
}
}
@ -246,7 +252,7 @@ func fnJoin(ce *WrappedCommandEvent) {
ce.Reply("Failed to join group: %v", err)
return
}
ce.Log.Debugln("%s successfully joined group %s", ce.User.MXID, jid)
ce.ZLog.Debug().Stringer("group_jid", jid).Msg("User successfully joined WhatsApp group with link")
ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
} else if strings.HasPrefix(ce.Args[0], whatsmeow.NewsletterLinkPrefix) {
info, err := ce.User.Client.GetNewsletterInfoWithInvite(ce.Args[0])
@ -259,14 +265,14 @@ func fnJoin(ce *WrappedCommandEvent) {
ce.Reply("Failed to follow channel: %v", err)
return
}
ce.Log.Debugln("%s successfully followed channel %s", ce.User.MXID, info.ID)
ce.ZLog.Debug().Stringer("channel_jid", info.ID).Msg("User successfully followed WhatsApp channel with link")
ce.Reply("Successfully followed channel `%s`, the portal should be created momentarily", info.ID)
} else {
ce.Reply("That doesn't look like a WhatsApp invite link")
}
}
func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) {
func tryDecryptEvent(ce *WrappedCommandEvent, evt *event.Event) (json.RawMessage, error) {
var data json.RawMessage
if evt.Type != event.EventEncrypted {
data = evt.Content.VeryRaw
@ -275,7 +281,7 @@ func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, e
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, err
}
decrypted, err := crypto.Decrypt(evt)
decrypted, err := ce.Bridge.Crypto.Decrypt(ce.Ctx, evt)
if err != nil {
return nil, err
}
@ -311,11 +317,11 @@ var cmdAccept = &commands.FullHandler{
func fnAccept(ce *WrappedCommandEvent) {
if len(ce.ReplyTo) == 0 {
ce.Reply("You must reply to a group invite message when using this command.")
} else if evt, err := ce.Portal.MainIntent().GetEvent(ce.RoomID, ce.ReplyTo); err != nil {
ce.Log.Errorln("Failed to get event %s to handle !wa accept command: %v", ce.ReplyTo, err)
} else if evt, err := ce.Portal.MainIntent().GetEvent(ce.Ctx, ce.RoomID, ce.ReplyTo); err != nil {
ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to get reply target event to handle !wa accept command")
ce.Reply("Failed to get reply event")
} else if rawContent, err := tryDecryptEvent(ce.Bridge.Crypto, evt); err != nil {
ce.Log.Errorln("Failed to decrypt event %s to handle !wa accept command: %v", ce.ReplyTo, err)
} else if rawContent, err := tryDecryptEvent(ce, evt); err != nil {
ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to decrypt reply target event to handle !wa accept command")
ce.Reply("Failed to decrypt reply event")
} else if meta, err := parseInviteMeta(rawContent); err != nil || meta == nil {
ce.Reply("That doesn't look like a group invite message.")
@ -344,16 +350,16 @@ func fnCreate(ce *WrappedCommandEvent) {
return
}
members, err := ce.Bot.JoinedMembers(ce.RoomID)
members, err := ce.Bot.JoinedMembers(ce.Ctx, ce.RoomID)
if err != nil {
ce.Reply("Failed to get room members: %v", err)
return
}
var roomNameEvent event.RoomNameEventContent
err = ce.Bot.StateEvent(ce.RoomID, event.StateRoomName, "", &roomNameEvent)
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateRoomName, "", &roomNameEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.Log.Errorln("Failed to get room name to create group:", err)
ce.ZLog.Err(err).Msg("Failed to get room name to create group")
ce.Reply("Failed to get room name")
return
} else if len(roomNameEvent.Name) == 0 {
@ -362,15 +368,17 @@ func fnCreate(ce *WrappedCommandEvent) {
}
var encryptionEvent event.EncryptionEventContent
err = ce.Bot.StateEvent(ce.RoomID, event.StateEncryption, "", &encryptionEvent)
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateEncryption, "", &encryptionEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.ZLog.Err(err).Msg("Failed to get room encryption status to create group")
ce.Reply("Failed to get room encryption status")
return
}
var createEvent event.CreateEventContent
err = ce.Bot.StateEvent(ce.RoomID, event.StateCreate, "", &createEvent)
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateCreate, "", &createEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.ZLog.Err(err).Msg("Failed to get room create event to create group")
ce.Reply("Failed to get room create event")
return
}
@ -395,7 +403,11 @@ func fnCreate(ce *WrappedCommandEvent) {
// TODO check m.space.parent to create rooms directly in communities
messageID := ce.User.Client.GenerateMessageID()
ce.Log.Infofln("Creating group for %s with name %s and participants %+v (create key: %s)", ce.RoomID, roomNameEvent.Name, participants, messageID)
ce.ZLog.Info().
Str("room_name", roomNameEvent.Name).
Any("participants", participants).
Str("create_key", messageID).
Msg("Creating WhatsApp group for Matrix room")
ce.User.createKeyDedup = messageID
resp, err := ce.User.Client.CreateGroup(whatsmeow.ReqCreateGroup{
CreateKey: messageID,
@ -409,21 +421,25 @@ func fnCreate(ce *WrappedCommandEvent) {
ce.Reply("Failed to create group: %v", err)
return
}
ce.ZLog.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("group_jid", resp.JID.String())
})
portal := ce.User.GetPortalByJID(resp.JID)
portal.roomCreateLock.Lock()
defer portal.roomCreateLock.Unlock()
if len(portal.MXID) != 0 {
portal.log.Warnln("Detected race condition in room creation")
ce.ZLog.Warn().Msg("Detected race condition in room creation")
// TODO race condition, clean up the old room
}
portal.MXID = ce.RoomID
portal.updateLogger()
portal.Name = roomNameEvent.Name
portal.IsParent = resp.IsParent
portal.Encrypted = encryptionEvent.Algorithm == id.AlgorithmMegolmV1
if !portal.Encrypted && ce.Bridge.Config.Bridge.Encryption.Default {
_, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent())
_, err = portal.MainIntent().SendStateEvent(ce.Ctx, portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent())
if err != nil {
portal.log.Warnln("Failed to enable encryption in room:", err)
ce.ZLog.Err(err).Msg("Failed to enable encryption in room")
if errors.Is(err, mautrix.MForbidden) {
ce.Reply("I don't seem to have permission to enable encryption in this room.")
} else {
@ -433,8 +449,11 @@ func fnCreate(ce *WrappedCommandEvent) {
portal.Encrypted = true
}
portal.Update(nil)
portal.UpdateBridgeInfo()
err = portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after creating group")
}
portal.UpdateBridgeInfo(ce.Ctx)
ce.User.createKeyDedup = ""
ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID)
@ -512,7 +531,7 @@ func fnLogin(ce *WrappedCommandEvent) {
}
}
if qrEventID != "" {
_, _ = ce.Bot.RedactEvent(ce.RoomID, qrEventID)
_, _ = ce.Bot.RedactEvent(ce.Ctx, ce.RoomID, qrEventID)
}
}
@ -529,9 +548,9 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
if len(prevEvent) != 0 {
content.SetEdit(prevEvent)
}
resp, err := ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &content)
resp, err := ce.Bot.SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, &content)
if err != nil {
user.log.Errorln("Failed to send edited QR code to user:", err)
ce.ZLog.Err(err).Msg("Failed to send edited QR code to user")
} else if len(prevEvent) == 0 {
prevEvent = resp.EventID
}
@ -541,16 +560,16 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
func (user *User) uploadQR(ce *WrappedCommandEvent, code string) (id.ContentURI, bool) {
qrCode, err := qrcode.Encode(code, qrcode.Low, 256)
if err != nil {
user.log.Errorln("Failed to encode QR code:", err)
ce.ZLog.Err(err).Msg("Failed to encode QR code")
ce.Reply("Failed to encode QR code: %v", err)
return id.ContentURI{}, false
}
bot := user.bridge.AS.BotClient()
resp, err := bot.UploadBytes(qrCode, "image/png")
resp, err := bot.UploadBytes(ce.Ctx, qrCode, "image/png")
if err != nil {
user.log.Errorln("Failed to upload QR code:", err)
ce.ZLog.Err(err).Msg("Failed to upload QR code")
ce.Reply("Failed to upload QR code: %v", err)
return id.ContentURI{}, false
}
@ -578,14 +597,14 @@ func fnLogout(ce *WrappedCommandEvent) {
puppet.ClearCustomMXID()
err := ce.User.Client.Logout()
if err != nil {
ce.User.log.Warnln("Error while logging out:", err)
ce.ZLog.Err(err).Msg("Unknown error while logging out")
ce.Reply("Unknown error while logging out: %v", err)
return
}
ce.User.Session = nil
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
ce.User.DeleteConnection()
ce.User.DeleteSession()
ce.User.DeleteSession(ce.Ctx)
ce.Reply("Logged out successfully.")
}
@ -620,10 +639,13 @@ func fnTogglePresence(ce *WrappedCommandEvent) {
if ce.User.IsLoggedIn() {
err := ce.User.Client.SendPresence(newPresence)
if err != nil {
ce.User.log.Warnln("Failed to set presence:", err)
ce.ZLog.Err(err).Msg("Failed to send presence to WhatsApp")
}
}
customPuppet.Update()
err := customPuppet.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save puppet after toggling presence")
}
}
var cmdDeleteSession = &commands.FullHandler{
@ -642,7 +664,7 @@ func fnDeleteSession(ce *WrappedCommandEvent) {
}
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
ce.User.DeleteConnection()
ce.User.DeleteSession()
ce.User.DeleteSession(ce.Ctx)
ce.Reply("Session information purged")
}
@ -716,19 +738,19 @@ func fnPing(ce *WrappedCommandEvent) {
}
}
func canDeletePortal(portal *Portal, userID id.UserID) bool {
func canDeletePortal(ce *WrappedCommandEvent, portal *Portal) bool {
if len(portal.MXID) == 0 {
return false
}
members, err := portal.MainIntent().JoinedMembers(portal.MXID)
members, err := portal.MainIntent().JoinedMembers(ce.Ctx, portal.MXID)
if err != nil {
portal.log.Errorfln("Failed to get joined members to check if portal can be deleted by %s: %v", userID, err)
ce.ZLog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to get joined members to check if portal can be deleted by user")
return false
}
for otherUser := range members.Joined {
_, isPuppet := portal.bridge.ParsePuppetMXID(otherUser)
if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == userID {
if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == ce.User.MXID {
continue
}
user := portal.bridge.GetUserByMXID(otherUser)
@ -750,14 +772,14 @@ var cmdDeletePortal = &commands.FullHandler{
}
func fnDeletePortal(ce *WrappedCommandEvent) {
if !ce.User.Admin && !canDeletePortal(ce.Portal, ce.User.MXID) {
if !ce.User.Admin && !canDeletePortal(ce, ce.Portal) {
ce.Reply("Only bridge admins can delete portals with other Matrix users")
return
}
ce.Portal.log.Infoln(ce.User.MXID, "requested deletion of portal.")
ce.Portal.Delete()
ce.Portal.Cleanup(false)
ce.ZLog.Info().Msg("User requested deletion of current portal")
ce.Portal.Delete(ce.Ctx)
ce.Portal.Cleanup(ce.Ctx, false)
}
var cmdDeleteAllPortals = &commands.FullHandler{
@ -778,7 +800,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
} else {
portalsToDelete = portals[:0]
for _, portal := range portals {
if canDeletePortal(portal, ce.User.MXID) {
if canDeletePortal(ce, portal) {
portalsToDelete = append(portalsToDelete, portal)
}
}
@ -790,7 +812,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
leave := func(portal *Portal) {
if len(portal.MXID) > 0 {
_, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{
_, _ = portal.MainIntent().KickUser(ce.Ctx, portal.MXID, &mautrix.ReqKickUser{
Reason: "Deleting portal",
UserID: ce.User.MXID,
})
@ -801,21 +823,21 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
intent := customPuppet.CustomIntent()
leave = func(portal *Portal) {
if len(portal.MXID) > 0 {
_, _ = intent.LeaveRoom(portal.MXID)
_, _ = intent.ForgetRoom(portal.MXID)
_, _ = intent.LeaveRoom(ce.Ctx, portal.MXID)
_, _ = intent.ForgetRoom(ce.Ctx, portal.MXID)
}
}
}
ce.Reply("Found %d portals, deleting...", len(portalsToDelete))
for _, portal := range portalsToDelete {
portal.Delete()
portal.Delete(ce.Ctx)
leave(portal)
}
ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.")
go func() {
for _, portal := range portalsToDelete {
portal.Cleanup(false)
portal.Cleanup(ce.Ctx, false)
}
ce.Reply("Finished background cleanup of deleted portal rooms.")
}()
@ -882,7 +904,7 @@ func fnList(ce *WrappedCommandEvent) {
}
var err error
page := 1
max := 100
maxPerPage := 100
if len(ce.Args) > 1 {
page, err = strconv.Atoi(ce.Args[1])
if err != nil || page <= 0 {
@ -891,11 +913,11 @@ func fnList(ce *WrappedCommandEvent) {
}
}
if len(ce.Args) > 2 {
max, err = strconv.Atoi(ce.Args[2])
if err != nil || max <= 0 {
maxPerPage, err = strconv.Atoi(ce.Args[2])
if err != nil || maxPerPage <= 0 {
ce.Reply("\"%s\" isn't a valid number of items per page", ce.Args[2])
return
} else if max > 400 {
} else if maxPerPage > 400 {
ce.Reply("Warning: a high number of items per page may fail to send a reply")
}
}
@ -924,8 +946,8 @@ func fnList(ce *WrappedCommandEvent) {
ce.Reply("No %s found", strings.ToLower(typeName))
return
}
pages := int(math.Ceil(float64(len(result)) / float64(max)))
if (page-1)*max >= len(result) {
pages := int(math.Ceil(float64(len(result)) / float64(maxPerPage)))
if (page-1)*maxPerPage >= len(result) {
if pages == 1 {
ce.Reply("There is only 1 page of %s", strings.ToLower(typeName))
} else {
@ -933,11 +955,11 @@ func fnList(ce *WrappedCommandEvent) {
}
return
}
lastIndex := page * max
lastIndex := page * maxPerPage
if lastIndex > len(result) {
lastIndex = len(result)
}
result = result[(page-1)*max : lastIndex]
result = result[(page-1)*maxPerPage : lastIndex]
ce.Reply("### %s (page %d of %d)\n\n%s", typeName, page, pages, strings.Join(result, "\n"))
}
@ -1036,13 +1058,13 @@ func fnOpen(ce *WrappedCommandEvent) {
}
jid = newsletterMetadata.ID
}
ce.Log.Debugln("Importing", jid, "for", ce.User.MXID)
ce.ZLog.Debug().Stringer("chat_jid", jid).Msg("Importing chat for user")
portal := ce.User.GetPortalByJID(jid)
if len(portal.MXID) > 0 {
portal.UpdateMatrixRoom(ce.User, groupInfo, newsletterMetadata)
portal.UpdateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata)
ce.Reply("Portal room synced.")
} else {
err = portal.CreateMatrixRoom(ce.User, groupInfo, newsletterMetadata, true, true)
err = portal.CreateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata, true, true)
if err != nil {
ce.Reply("Failed to create room: %v", err)
} else {
@ -1085,7 +1107,7 @@ func fnPM(ce *WrappedCommandEvent) {
return
}
portal, puppet, justCreated, err := user.StartPM(targetUser.JID, "manual PM command")
portal, puppet, justCreated, err := user.StartPM(ce.Ctx, targetUser.JID, "manual PM command")
if err != nil {
ce.Reply("Failed to create portal room: %v", err)
} else if !justCreated {
@ -1154,11 +1176,16 @@ func fnSync(ce *WrappedCommandEvent) {
ce.Reply("Personal filtering spaces are not enabled on this instance of the bridge")
return
}
keys := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.JID)
keys, err := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.Ctx, ce.User.JID)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to get list of private chats not in space")
ce.Reply("Failed to get list of private chats not in space")
return
}
count := 0
for _, key := range keys {
portal := ce.Bridge.GetPortalByJID(key)
portal.addToPersonalSpace(ce.User)
portal.addToPersonalSpace(ce.Ctx, ce.User)
count++
}
plural := "s"
@ -1208,6 +1235,9 @@ func fnDisappearingTimer(ce *WrappedCommandEvent) {
ce.Portal.ExpirationTime = prevExpirationTime
return
}
ce.Portal.Update(nil)
err = ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after setting disappearing timer")
}
ce.React("✅")
}

View File

@ -17,6 +17,9 @@
package main
import (
"context"
"fmt"
"maunium.net/go/mautrix/id"
)
@ -24,8 +27,11 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error
puppet.CustomMXID = mxid
puppet.AccessToken = accessToken
puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence
puppet.Update()
err := puppet.StartCustomMXID(false)
err := puppet.Update(context.TODO())
if err != nil {
return fmt.Errorf("failed to save access token: %w", err)
}
err = puppet.StartCustomMXID(false)
if err != nil {
return err
}
@ -45,12 +51,15 @@ func (puppet *Puppet) ClearCustomMXID() {
puppet.customIntent = nil
puppet.customUser = nil
if save {
puppet.Update()
err := puppet.Update(context.TODO())
if err != nil {
puppet.zlog.Err(err).Msg("Failed to clear custom MXID")
}
}
}
func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(context.TODO(), puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
if err != nil {
puppet.ClearCustomMXID()
return err
@ -60,11 +69,11 @@ func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
puppet.bridge.puppetsLock.Unlock()
if puppet.AccessToken != newAccessToken {
puppet.AccessToken = newAccessToken
puppet.Update()
err = puppet.Update(context.TODO())
}
puppet.customIntent = newIntent
puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID)
return nil
return err
}
func (user *User) tryAutomaticDoublePuppeting() {

View File

@ -1,340 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type BackfillType int
const (
BackfillImmediate BackfillType = 0
BackfillForward BackfillType = 100
BackfillDeferred BackfillType = 200
)
func (bt BackfillType) String() string {
switch bt {
case BackfillImmediate:
return "IMMEDIATE"
case BackfillForward:
return "FORWARD"
case BackfillDeferred:
return "DEFERRED"
}
return "UNKNOWN"
}
type BackfillQuery struct {
db *Database
log log.Logger
backfillQueryLock sync.Mutex
}
func (bq *BackfillQuery) New() *Backfill {
return &Backfill{
db: bq.db,
log: bq.log,
Portal: &PortalKey{},
}
}
func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill {
return &Backfill{
db: bq.db,
log: bq.log,
UserID: userID,
BackfillType: backfillType,
Priority: priority,
Portal: portal,
TimeStart: timeStart,
MaxBatchEvents: maxBatchEvents,
MaxTotalEvents: maxTotalEvents,
BatchDelay: batchDelay,
}
}
const (
getNextBackfillQuery = `
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (
dispatch_time IS NULL
OR (
dispatch_time < $2
AND completed_at IS NULL
)
)
ORDER BY type, priority, queue_id
LIMIT 1
`
getUnstartedOrInFlightQuery = `
SELECT 1
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (dispatch_time IS NULL OR completed_at IS NULL)
LIMIT 1
`
)
// GetNext returns the next backfill to perform
func (bq *BackfillQuery) GetNext(userID id.UserID, backfillTypes []BackfillType) (backfill *Backfill) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
var types []string
for _, backfillType := range backfillTypes {
types = append(types, strconv.Itoa(int(backfillType)))
}
rows, err := bq.db.Query(fmt.Sprintf(getNextBackfillQuery, strings.Join(types, ",")), userID, time.Now().Add(-15*time.Minute))
if err != nil || rows == nil {
bq.log.Errorfln("Failed to query next backfill queue job: %v", err)
return
}
defer rows.Close()
if rows.Next() {
backfill = bq.New().Scan(rows)
}
return
}
func (bq *BackfillQuery) HasUnstartedOrInFlightOfType(userID id.UserID, backfillTypes []BackfillType) bool {
if len(backfillTypes) == 0 {
return false
}
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
types := []string{}
for _, backfillType := range backfillTypes {
types = append(types, strconv.Itoa(int(backfillType)))
}
rows, err := bq.db.Query(fmt.Sprintf(getUnstartedOrInFlightQuery, strings.Join(types, ",")), userID)
if err != nil || rows == nil {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
bq.log.Warnfln("Failed to query backfill queue jobs: %v", err)
}
// No rows means that there are no unstarted or in flight backfill
// requests.
return false
}
defer rows.Close()
return rows.Next()
}
func (bq *BackfillQuery) DeleteAll(userID id.UserID) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
_, err := bq.db.Exec("DELETE FROM backfill_queue WHERE user_mxid=$1", userID)
if err != nil {
bq.log.Warnfln("Failed to delete backfill queue items for %s: %v", userID, err)
}
}
func (bq *BackfillQuery) DeleteAllForPortal(userID id.UserID, portalKey PortalKey) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
_, err := bq.db.Exec(`
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`, userID, portalKey.JID, portalKey.Receiver)
if err != nil {
bq.log.Warnfln("Failed to delete backfill queue items for %s/%s: %v", userID, portalKey.JID, err)
}
}
type Backfill struct {
db *Database
log log.Logger
// Fields
QueueID int
UserID id.UserID
BackfillType BackfillType
Priority int
Portal *PortalKey
TimeStart *time.Time
MaxBatchEvents int
MaxTotalEvents int
BatchDelay int
DispatchTime *time.Time
CompletedAt *time.Time
}
func (b *Backfill) String() string {
return fmt.Sprintf("Backfill{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
)
}
func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
var maxTotalEvents, batchDelay sql.NullInt32
err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &maxTotalEvents, &batchDelay)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
b.log.Errorln("Database scan failed:", err)
}
return nil
}
b.MaxTotalEvents = int(maxTotalEvents.Int32)
b.BatchDelay = int(batchDelay.Int32)
return b
}
func (b *Backfill) Insert() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
rows, err := b.db.Query(`
INSERT INTO backfill_queue
(user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING queue_id
`, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt)
defer rows.Close()
if err != nil || !rows.Next() {
b.log.Warnfln("Failed to insert %v/%s with priority %d: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
return
}
err = rows.Scan(&b.QueueID)
if err != nil {
b.log.Warnfln("Failed to insert %s/%s with priority %s: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
}
}
func (b *Backfill) MarkDispatched() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
b.log.Errorfln("Cannot mark backfill as dispatched without queue_id. Maybe it wasn't actually inserted in the database?")
return
}
_, err := b.db.Exec("UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
if err != nil {
b.log.Warnfln("Failed to mark %s/%s as dispatched: %v", b.BackfillType, b.Priority, err)
}
}
func (b *Backfill) MarkDone() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
b.log.Errorfln("Cannot mark backfill done without queue_id. Maybe it wasn't actually inserted in the database?")
return
}
_, err := b.db.Exec("UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
if err != nil {
b.log.Warnfln("Failed to mark %s/%s as complete: %v", b.BackfillType, b.Priority, err)
}
}
func (bq *BackfillQuery) NewBackfillState(userID id.UserID, portalKey *PortalKey) *BackfillState {
return &BackfillState{
db: bq.db,
log: bq.log,
UserID: userID,
Portal: portalKey,
}
}
const (
getBackfillState = `
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
FROM backfill_state
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
)
type BackfillState struct {
db *Database
log log.Logger
// Fields
UserID id.UserID
Portal *PortalKey
ProcessingBatch bool
BackfillComplete bool
FirstExpectedTimestamp uint64
}
func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
b.log.Errorln("Database scan failed:", err)
}
return nil
}
return b
}
func (b *BackfillState) Upsert() {
_, err := b.db.Exec(`
INSERT INTO backfill_state
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
DO UPDATE SET
processing_batch=EXCLUDED.processing_batch,
backfill_complete=EXCLUDED.backfill_complete,
first_expected_ts=EXCLUDED.first_expected_ts`,
b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp)
if err != nil {
b.log.Warnfln("Failed to insert backfill state for %s: %v", b.Portal.JID, err)
}
}
func (b *BackfillState) SetProcessingBatch(processing bool) {
b.ProcessingBatch = processing
b.Upsert()
}
func (bq *BackfillQuery) GetBackfillState(userID id.UserID, portalKey *PortalKey) (backfillState *BackfillState) {
rows, err := bq.db.Query(getBackfillState, userID, portalKey.JID, portalKey.Receiver)
if err != nil || rows == nil {
bq.log.Error(err)
return
}
defer rows.Close()
if rows.Next() {
backfillState = bq.NewBackfillState(userID, portalKey).Scan(rows)
}
return
}

253
database/backfillqueue.go Normal file
View File

@ -0,0 +1,253 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
type BackfillType int
const (
BackfillImmediate BackfillType = 0
BackfillForward BackfillType = 100
BackfillDeferred BackfillType = 200
)
func (bt BackfillType) String() string {
switch bt {
case BackfillImmediate:
return "IMMEDIATE"
case BackfillForward:
return "FORWARD"
case BackfillDeferred:
return "DEFERRED"
}
return "UNKNOWN"
}
type BackfillTaskQuery struct {
*dbutil.QueryHelper[*BackfillTask]
//backfillQueryLock sync.Mutex
}
func newBackfillTask(qh *dbutil.QueryHelper[*BackfillTask]) *BackfillTask {
return &BackfillTask{qh: qh}
}
func (bq *BackfillTaskQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *BackfillTask {
return &BackfillTask{
qh: bq.QueryHelper,
UserID: userID,
BackfillType: backfillType,
Priority: priority,
Portal: portal,
TimeStart: timeStart,
MaxBatchEvents: maxBatchEvents,
MaxTotalEvents: maxTotalEvents,
BatchDelay: batchDelay,
}
}
const (
getNextBackfillTaskQueryTemplate = `
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (
dispatch_time IS NULL
OR (
dispatch_time < $2
AND completed_at IS NULL
)
)
ORDER BY type, priority, queue_id
LIMIT 1
`
getUnstartedOrInFlightBackfillTaskQueryTemplate = `
SELECT 1
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (dispatch_time IS NULL OR completed_at IS NULL)
LIMIT 1
`
deleteBackfillQueueForUserQuery = "DELETE FROM backfill_queue WHERE user_mxid=$1"
deleteBackfillQueueForPortalQuery = `
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
insertBackfillTaskQuery = `
INSERT INTO backfill_queue (
user_mxid, type, priority, portal_jid, portal_receiver, time_start,
max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING queue_id
`
markBackfillTaskDispatchedQuery = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"
markBackfillTaskDoneQuery = "UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2"
)
func typesToString(backfillTypes []BackfillType) string {
types := make([]string, len(backfillTypes))
for i, backfillType := range backfillTypes {
types[i] = strconv.Itoa(int(backfillType))
}
return strings.Join(types, ",")
}
// GetNext returns the next backfill to perform
func (bq *BackfillTaskQuery) GetNext(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (*BackfillTask, error) {
if len(backfillTypes) == 0 {
return nil, nil
}
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
query := fmt.Sprintf(getNextBackfillTaskQueryTemplate, typesToString(backfillTypes))
return bq.QueryOne(ctx, query, userID, time.Now().Add(-15*time.Minute))
}
func (bq *BackfillTaskQuery) HasUnstartedOrInFlightOfType(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (has bool) {
if len(backfillTypes) == 0 {
return false
}
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
query := fmt.Sprintf(getUnstartedOrInFlightBackfillTaskQueryTemplate, typesToString(backfillTypes))
err := bq.GetDB().QueryRow(ctx, query, userID).Scan(&has)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
zerolog.Ctx(ctx).Err(err).Msg("Failed to check if backfill queue has jobs")
}
return
}
func (bq *BackfillTaskQuery) DeleteAll(ctx context.Context, userID id.UserID) error {
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
return bq.Exec(ctx, deleteBackfillQueueForUserQuery, userID)
}
func (bq *BackfillTaskQuery) DeleteAllForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
return bq.Exec(ctx, deleteBackfillQueueForPortalQuery, userID, portalKey.JID, portalKey.Receiver)
}
type BackfillTask struct {
qh *dbutil.QueryHelper[*BackfillTask]
QueueID int
UserID id.UserID
BackfillType BackfillType
Priority int
Portal PortalKey
TimeStart *time.Time
MaxBatchEvents int
MaxTotalEvents int
BatchDelay int
DispatchTime *time.Time
CompletedAt *time.Time
}
func (b *BackfillTask) MarshalZerologObject(evt *zerolog.Event) {
evt.Int("queue_id", b.QueueID).
Stringer("user_id", b.UserID).
Stringer("backfill_type", b.BackfillType).
Int("priority", b.Priority).
Stringer("portal_jid", b.Portal.JID).
Any("time_start", b.TimeStart).
Int("max_batch_events", b.MaxBatchEvents).
Int("max_total_events", b.MaxTotalEvents).
Int("batch_delay", b.BatchDelay).
Any("dispatch_time", b.DispatchTime).
Any("completed_at", b.CompletedAt)
}
func (b *BackfillTask) String() string {
return fmt.Sprintf(
"BackfillTask{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
)
}
func (b *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) {
var maxTotalEvents, batchDelay sql.NullInt32
err := row.Scan(
&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart,
&b.MaxBatchEvents, &maxTotalEvents, &batchDelay,
)
if err != nil {
return nil, err
}
b.MaxTotalEvents = int(maxTotalEvents.Int32)
b.BatchDelay = int(batchDelay.Int32)
return b, nil
}
func (b *BackfillTask) sqlVariables() []any {
return []any{
b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart,
b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt,
}
}
func (b *BackfillTask) Insert(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
return b.qh.GetDB().QueryRow(ctx, insertBackfillTaskQuery, b.sqlVariables()...).Scan(&b.QueueID)
}
func (b *BackfillTask) MarkDispatched(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
}
return b.qh.Exec(ctx, markBackfillTaskDispatchedQuery, time.Now(), b.QueueID)
}
func (b *BackfillTask) MarkDone(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
}
return b.qh.Exec(ctx, markBackfillTaskDoneQuery, time.Now(), b.QueueID)
}

94
database/backfillstate.go Normal file
View File

@ -0,0 +1,94 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
type BackfillStateQuery struct {
*dbutil.QueryHelper[*BackfillState]
}
func newBackfillState(qh *dbutil.QueryHelper[*BackfillState]) *BackfillState {
return &BackfillState{qh: qh}
}
func (bq *BackfillStateQuery) NewBackfillState(userID id.UserID, portalKey PortalKey) *BackfillState {
return &BackfillState{
qh: bq.QueryHelper,
UserID: userID,
Portal: portalKey,
}
}
const (
getBackfillStateQuery = `
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
FROM backfill_state
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
upsertBackfillStateQuery = `
INSERT INTO backfill_state
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
DO UPDATE SET
processing_batch=EXCLUDED.processing_batch,
backfill_complete=EXCLUDED.backfill_complete,
first_expected_ts=EXCLUDED.first_expected_ts
`
)
func (bq *BackfillStateQuery) GetBackfillState(ctx context.Context, userID id.UserID, portalKey PortalKey) (*BackfillState, error) {
return bq.QueryOne(ctx, getBackfillStateQuery, userID, portalKey.JID, portalKey.Receiver)
}
type BackfillState struct {
qh *dbutil.QueryHelper[*BackfillState]
UserID id.UserID
Portal PortalKey
ProcessingBatch bool
BackfillComplete bool
FirstExpectedTimestamp uint64
}
func (b *BackfillState) Scan(row dbutil.Scannable) (*BackfillState, error) {
return dbutil.ValueOrErr(b, row.Scan(
&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp,
))
}
func (b *BackfillState) sqlVariables() []any {
return []any{b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp}
}
func (b *BackfillState) Upsert(ctx context.Context) error {
return b.qh.Exec(ctx, upsertBackfillStateQuery, b.sqlVariables()...)
}
func (b *BackfillState) SetProcessingBatch(ctx context.Context, processing bool) error {
b.ProcessingBatch = processing
return b.Upsert(ctx)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -26,7 +26,6 @@ import (
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-whatsapp/database/upgrades"
)
@ -45,51 +44,28 @@ type Database struct {
Reaction *ReactionQuery
DisappearingMessage *DisappearingMessageQuery
Backfill *BackfillQuery
BackfillQueue *BackfillTaskQuery
BackfillState *BackfillStateQuery
HistorySync *HistorySyncQuery
MediaBackfillRequest *MediaBackfillRequestQuery
}
func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
db := &Database{Database: baseDB}
func New(db *dbutil.Database) *Database {
db.UpgradeTable = upgrades.Table
db.User = &UserQuery{
db: db,
log: log.Sub("User"),
return &Database{
Database: db,
User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
DisappearingMessage: &DisappearingMessageQuery{dbutil.MakeQueryHelper(db, newDisappearingMessage)},
BackfillQueue: &BackfillTaskQuery{dbutil.MakeQueryHelper(db, newBackfillTask)},
BackfillState: &BackfillStateQuery{dbutil.MakeQueryHelper(db, newBackfillState)},
HistorySync: &HistorySyncQuery{dbutil.MakeQueryHelper(db, newHistorySyncConversation)},
MediaBackfillRequest: &MediaBackfillRequestQuery{dbutil.MakeQueryHelper(db, newMediaBackfillRequest)},
}
db.Portal = &PortalQuery{
db: db,
log: log.Sub("Portal"),
}
db.Puppet = &PuppetQuery{
db: db,
log: log.Sub("Puppet"),
}
db.Message = &MessageQuery{
db: db,
log: log.Sub("Message"),
}
db.Reaction = &ReactionQuery{
db: db,
log: log.Sub("Reaction"),
}
db.DisappearingMessage = &DisappearingMessageQuery{
db: db,
log: log.Sub("DisappearingMessage"),
}
db.Backfill = &BackfillQuery{
db: db,
log: log.Sub("Backfill"),
}
db.HistorySync = &HistorySyncQuery{
db: db,
log: log.Sub("HistorySync"),
}
db.MediaBackfillRequest = &MediaBackfillRequestQuery{
db: db,
log: log.Sub("MediaBackfillRequest"),
}
return db
}
func isRetryableError(err error) bool {

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,32 +17,29 @@
package database
import (
"context"
"database/sql"
"errors"
"time"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
)
type DisappearingMessageQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*DisappearingMessage]
}
func (dmq *DisappearingMessageQuery) New() *DisappearingMessage {
func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage {
return &DisappearingMessage{
db: dmq.db,
log: dmq.log,
qh: qh,
}
}
func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage {
dm := &DisappearingMessage{
db: dmq.db,
log: dmq.log,
qh: dmq.QueryHelper,
RoomID: roomID,
EventID: eventID,
ExpireIn: expireIn,
@ -55,22 +52,17 @@ const (
getAllScheduledDisappearingMessagesQuery = `
SELECT room_id, event_id, expire_in, expire_at FROM disappearing_message WHERE expire_at IS NOT NULL AND expire_at <= $1
`
insertDisappearingMessageQuery = `INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`
updateDisappearingMessageExpiryQuery = "UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3"
deleteDisappearingMessageQuery = "DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2"
)
func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(duration time.Duration) (messages []*DisappearingMessage) {
rows, err := dmq.db.Query(getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli())
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, dmq.New().Scan(rows))
}
return
func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli())
}
type DisappearingMessage struct {
db *Database
log log.Logger
qh *dbutil.QueryHelper[*DisappearingMessage]
RoomID id.RoomID
EventID id.EventID
@ -78,50 +70,33 @@ type DisappearingMessage struct {
ExpireAt time.Time
}
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
var expireIn int64
var expireAt sql.NullInt64
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
msg.log.Errorln("Database scan failed:", err)
}
return nil
return nil, err
}
msg.ExpireIn = time.Duration(expireIn) * time.Millisecond
if expireAt.Valid {
msg.ExpireAt = time.UnixMilli(expireAt.Int64)
}
return msg
return msg, nil
}
func (msg *DisappearingMessage) Insert(txn dbutil.Execable) {
if txn == nil {
txn = msg.db
}
var expireAt sql.NullInt64
if !msg.ExpireAt.IsZero() {
expireAt.Valid = true
expireAt.Int64 = msg.ExpireAt.UnixMilli()
}
_, err := txn.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`,
msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), expireAt)
if err != nil {
msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err)
}
func (msg *DisappearingMessage) sqlVariables() []any {
return []any{msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), dbutil.UnixMilliPtr(msg.ExpireAt)}
}
func (msg *DisappearingMessage) StartTimer() {
func (msg *DisappearingMessage) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...)
}
func (msg *DisappearingMessage) StartTimer(ctx context.Context) error {
msg.ExpireAt = time.Now().Add(msg.ExpireIn * time.Second)
_, err := msg.db.Exec("UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3", msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to update %s/%s: %v", msg.RoomID, msg.EventID, err)
}
return msg.qh.Exec(ctx, updateDisappearingMessageExpiryQuery, msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
}
func (msg *DisappearingMessage) Delete() {
_, err := msg.db.Exec("DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2", msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to delete %s/%s: %v", msg.RoomID, msg.EventID, err)
}
func (msg *DisappearingMessage) Delete(ctx context.Context) error {
return msg.qh.Exec(ctx, deleteDisappearingMessageQuery, msg.RoomID, msg.EventID)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan, Sumner Evans
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,8 +17,7 @@
package database
import (
"database/sql"
"errors"
"context"
"fmt"
"time"
@ -26,23 +25,19 @@ import (
"go.mau.fi/util/dbutil"
waProto "go.mau.fi/whatsmeow/binary/proto"
"google.golang.org/protobuf/proto"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type HistorySyncQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*HistorySyncConversation]
}
type HistorySyncConversation struct {
db *Database
log log.Logger
qh *dbutil.QueryHelper[*HistorySyncConversation]
UserID id.UserID
ConversationID string
PortalKey *PortalKey
PortalKey PortalKey
LastMessageTimestamp time.Time
MuteEndTime time.Time
Archived bool
@ -54,18 +49,16 @@ type HistorySyncConversation struct {
UnreadCount uint32
}
func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation {
func newHistorySyncConversation(qh *dbutil.QueryHelper[*HistorySyncConversation]) *HistorySyncConversation {
return &HistorySyncConversation{
db: hsq.db,
log: hsq.log,
PortalKey: &PortalKey{},
qh: qh,
}
}
func (hsq *HistorySyncQuery) NewConversationWithValues(
userID id.UserID,
conversationID string,
portalKey *PortalKey,
portalKey PortalKey,
lastMessageTimestamp,
muteEndTime uint64,
archived bool,
@ -74,10 +67,10 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
endOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType,
ephemeralExpiration *uint32,
markedAsUnread bool,
unreadCount uint32) *HistorySyncConversation {
unreadCount uint32,
) *HistorySyncConversation {
return &HistorySyncConversation{
db: hsq.db,
log: hsq.log,
qh: hsq.QueryHelper,
UserID: userID,
ConversationID: conversationID,
PortalKey: portalKey,
@ -94,6 +87,17 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
}
const (
upsertHistorySyncConversationQuery = `
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (user_mxid, conversation_id)
DO UPDATE SET
last_message_timestamp=CASE
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
ELSE history_sync_conversation.last_message_timestamp
END,
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
`
getNMostRecentConversations = `
SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
FROM history_sync_conversation
@ -108,24 +112,19 @@ const (
AND portal_jid=$2
AND portal_receiver=$3
`
deleteAllConversationsQuery = "DELETE FROM history_sync_conversation WHERE user_mxid=$1"
deleteHistorySyncConversationQuery = `
DELETE FROM history_sync_conversation
WHERE user_mxid=$1 AND conversation_id=$2
`
)
func (hsc *HistorySyncConversation) Upsert() {
_, err := hsc.db.Exec(`
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (user_mxid, conversation_id)
DO UPDATE SET
last_message_timestamp=CASE
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
ELSE history_sync_conversation.last_message_timestamp
END,
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
`,
func (hsc *HistorySyncConversation) sqlVariables() []any {
return []any{
hsc.UserID,
hsc.ConversationID,
hsc.PortalKey.JID.String(),
hsc.PortalKey.Receiver.String(),
hsc.PortalKey.JID,
hsc.PortalKey.Receiver,
hsc.LastMessageTimestamp,
hsc.Archived,
hsc.Pinned,
@ -134,14 +133,16 @@ func (hsc *HistorySyncConversation) Upsert() {
hsc.EndOfHistoryTransferType,
hsc.EphemeralExpiration,
hsc.MarkedAsUnread,
hsc.UnreadCount)
if err != nil {
hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err)
hsc.UnreadCount,
}
}
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
err := row.Scan(
func (hsc *HistorySyncConversation) Upsert(ctx context.Context) error {
return hsc.qh.Exec(ctx, upsertHistorySyncConversationQuery, hsc.sqlVariables()...)
}
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) (*HistorySyncConversation, error) {
return dbutil.ValueOrErr(hsc, row.Scan(
&hsc.UserID,
&hsc.ConversationID,
&hsc.PortalKey.JID,
@ -154,69 +155,59 @@ func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConve
&hsc.EndOfHistoryTransferType,
&hsc.EphemeralExpiration,
&hsc.MarkedAsUnread,
&hsc.UnreadCount)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
hsc.log.Errorln("Database scan failed:", err)
}
return nil
}
return hsc
&hsc.UnreadCount,
))
}
func (hsq *HistorySyncQuery) GetRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
func (hsq *HistorySyncQuery) GetRecentConversations(ctx context.Context, userID id.UserID, n int) ([]*HistorySyncConversation, error) {
nPtr := &n
// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
if n < 0 && hsq.db.Dialect == dbutil.Postgres {
if n < 0 && hsq.GetDB().Dialect == dbutil.Postgres {
nPtr = nil
}
rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
for rows.Next() {
conversations = append(conversations, hsq.NewConversation().Scan(rows))
}
return
return hsq.QueryMany(ctx, getNMostRecentConversations, userID, nPtr)
}
func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey PortalKey) (conversation *HistorySyncConversation) {
rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
if rows.Next() {
conversation = hsq.NewConversation().Scan(rows)
}
return
func (hsq *HistorySyncQuery) GetConversation(ctx context.Context, userID id.UserID, portalKey PortalKey) (*HistorySyncConversation, error) {
return hsq.QueryOne(ctx, getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
}
func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) {
_, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical chat info for %s/%s: %v", userID, err)
}
func (hsq *HistorySyncQuery) DeleteAllConversations(ctx context.Context, userID id.UserID) error {
return hsq.Exec(ctx, deleteAllConversationsQuery, userID)
}
const (
getMessagesBetween = `
insertHistorySyncMessageQuery = `
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
`
getHistorySyncMessagesBetweenQueryTemplate = `
SELECT data FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
%s
ORDER BY timestamp DESC
%s
`
deleteMessagesBetweenExclusive = `
deleteHistorySyncMessagesBetweenExclusiveQuery = `
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2 AND timestamp<$3 AND timestamp>$4
`
deleteAllHistorySyncMessagesQuery = "DELETE FROM history_sync_message WHERE user_mxid=$1"
deleteHistorySyncMessagesForPortalQuery = `
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
`
conversationHasHistorySyncMessagesQuery = `
SELECT EXISTS(
SELECT 1 FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
)
`
)
type HistorySyncMessage struct {
db *Database
log log.Logger
hsq *HistorySyncQuery
UserID id.UserID
ConversationID string
@ -231,8 +222,8 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
return nil, err
}
return &HistorySyncMessage{
db: hsq.db,
log: hsq.log,
hsq: hsq,
UserID: userID,
ConversationID: conversationID,
MessageID: messageID,
@ -241,18 +232,27 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
}, nil
}
func (hsm *HistorySyncMessage) Insert() error {
_, err := hsm.db.Exec(`
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
return err
func (hsm *HistorySyncMessage) Insert(ctx context.Context) error {
return hsm.hsq.Exec(ctx, insertHistorySyncMessageQuery, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
}
func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) {
func scanWebMessageInfo(rows dbutil.Scannable) (*waProto.WebMessageInfo, error) {
var msgData []byte
err := rows.Scan(&msgData)
if err != nil {
return nil, err
}
var historySyncMsg waProto.HistorySyncMsg
err = proto.Unmarshal(msgData, &historySyncMsg)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal message: %w", err)
}
return historySyncMsg.GetMessage(), nil
}
func (hsq *HistorySyncQuery) GetMessagesBetween(ctx context.Context, userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) ([]*waProto.WebMessageInfo, error) {
whereClauses := ""
args := []interface{}{userID, conversationID}
args := []any{userID, conversationID}
argNum := 3
if startTime != nil {
whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum)
@ -268,80 +268,35 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
if limit > 0 {
limitClause = fmt.Sprintf("LIMIT %d", limit)
}
query := fmt.Sprintf(getHistorySyncMessagesBetweenQueryTemplate, whereClauses, limitClause)
rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...)
defer rows.Close()
if err != nil || rows == nil {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
hsq.log.Warnfln("Failed to query messages between range: %v", err)
}
return nil
}
var msgData []byte
for rows.Next() {
err = rows.Scan(&msgData)
if err != nil {
hsq.log.Errorfln("Database scan failed: %v", err)
continue
}
var historySyncMsg waProto.HistorySyncMsg
err = proto.Unmarshal(msgData, &historySyncMsg)
if err != nil {
hsq.log.Errorfln("Failed to unmarshal history sync message: %v", err)
continue
}
messages = append(messages, historySyncMsg.Message)
}
return
return dbutil.ConvertRowFn[*waProto.WebMessageInfo](scanWebMessageInfo).
NewRowIter(hsq.GetDB().Query(ctx, query, args...)).
AsList()
}
func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
func (hsq *HistorySyncQuery) DeleteMessages(ctx context.Context, userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
newest := messages[0]
beforeTS := time.Unix(int64(newest.GetMessageTimestamp())+1, 0)
oldest := messages[len(messages)-1]
afterTS := time.Unix(int64(oldest.GetMessageTimestamp())-1, 0)
_, err := hsq.db.Exec(deleteMessagesBetweenExclusive, userID, conversationID, beforeTS, afterTS)
return err
return hsq.Exec(ctx, deleteHistorySyncMessagesBetweenExclusiveQuery, userID, conversationID, beforeTS, afterTS)
}
func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) {
_, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s: %v", userID, err)
}
func (hsq *HistorySyncQuery) DeleteAllMessages(ctx context.Context, userID id.UserID) error {
return hsq.Exec(ctx, deleteAllHistorySyncMessagesQuery, userID)
}
func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(userID id.UserID, portalKey PortalKey) {
_, err := hsq.db.Exec(`
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
`, userID, portalKey.JID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, portalKey.JID, err)
}
func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
return hsq.Exec(ctx, deleteHistorySyncMessagesForPortalQuery, userID, portalKey.JID)
}
func (hsq *HistorySyncQuery) ConversationHasMessages(userID id.UserID, portalKey PortalKey) (exists bool) {
err := hsq.db.QueryRow(`
SELECT EXISTS(
SELECT 1 FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
)
`, userID, portalKey.JID).Scan(&exists)
if err != nil {
hsq.log.Warnfln("Failed to check if any messages are stored for %s/%s: %v", userID, portalKey.JID, err)
}
func (hsq *HistorySyncQuery) ConversationHasMessages(ctx context.Context, userID id.UserID, portalKey PortalKey) (exists bool, err error) {
err = hsq.GetDB().QueryRow(ctx, conversationHasHistorySyncMessagesQuery, userID, portalKey.JID).Scan(&exists)
return
}
func (hsq *HistorySyncQuery) DeleteConversation(userID id.UserID, jid string) {
func (hsq *HistorySyncQuery) DeleteConversation(ctx context.Context, userID id.UserID, jid string) error {
// This will also clear history_sync_message as there's a foreign key constraint
_, err := hsq.db.Exec(`
DELETE FROM history_sync_conversation
WHERE user_mxid=$1 AND conversation_id=$2
`, userID, jid)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, jid, err)
}
return hsq.Exec(ctx, deleteHistorySyncConversationQuery, userID, jid)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan, Sumner Evans
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,14 +17,12 @@
package database
import (
"database/sql"
"errors"
"context"
_ "github.com/mattn/go-sqlite3"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
)
type MediaBackfillRequestStatus int
@ -36,34 +34,46 @@ const (
)
type MediaBackfillRequestQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*MediaBackfillRequest]
}
type MediaBackfillRequest struct {
db *Database
log log.Logger
const (
getAllMediaBackfillRequestsForUserQuery = `
SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
FROM media_backfill_requests
WHERE user_mxid=$1
AND status=0
`
deleteAllMediaBackfillRequestsForUserQuery = "DELETE FROM media_backfill_requests WHERE user_mxid=$1"
upsertMediaBackfillRequestQuery = `
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id)
DO UPDATE SET
media_key=excluded.media_key,
status=excluded.status,
error=excluded.error
`
)
UserID id.UserID
PortalKey *PortalKey
EventID id.EventID
MediaKey []byte
Status MediaBackfillRequestStatus
Error string
func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(ctx context.Context, userID id.UserID) ([]*MediaBackfillRequest, error) {
return mbrq.QueryMany(ctx, getAllMediaBackfillRequestsForUserQuery, userID)
}
func (mbrq *MediaBackfillRequestQuery) newMediaBackfillRequest() *MediaBackfillRequest {
func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(ctx context.Context, userID id.UserID) error {
return mbrq.Exec(ctx, deleteAllMediaBackfillRequestsForUserQuery, userID)
}
func newMediaBackfillRequest(qh *dbutil.QueryHelper[*MediaBackfillRequest]) *MediaBackfillRequest {
return &MediaBackfillRequest{
db: mbrq.db,
log: mbrq.log,
PortalKey: &PortalKey{},
qh: qh,
}
}
func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest {
func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest {
return &MediaBackfillRequest{
db: mbrq.db,
log: mbrq.log,
qh: mbrq.QueryHelper,
UserID: userID,
PortalKey: portalKey,
EventID: eventID,
@ -72,62 +82,25 @@ func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID
}
}
const (
getMediaBackfillRequestsForUser = `
SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
FROM media_backfill_requests
WHERE user_mxid=$1
AND status=0
`
)
type MediaBackfillRequest struct {
qh *dbutil.QueryHelper[*MediaBackfillRequest]
func (mbr *MediaBackfillRequest) Upsert() {
_, err := mbr.db.Exec(`
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id)
DO UPDATE SET
media_key=EXCLUDED.media_key,
status=EXCLUDED.status,
error=EXCLUDED.error`,
mbr.UserID,
mbr.PortalKey.JID.String(),
mbr.PortalKey.Receiver.String(),
mbr.EventID,
mbr.MediaKey,
mbr.Status,
mbr.Error)
if err != nil {
mbr.log.Warnfln("Failed to insert media backfill request %s/%s/%s: %v", mbr.UserID, mbr.PortalKey.String(), mbr.EventID, err)
}
UserID id.UserID
PortalKey PortalKey
EventID id.EventID
MediaKey []byte
Status MediaBackfillRequestStatus
Error string
}
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest {
err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
mbr.log.Errorln("Database scan failed:", err)
}
return nil
}
return mbr
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) (*MediaBackfillRequest, error) {
return dbutil.ValueOrErr(mbr, row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error))
}
func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(userID id.UserID) (requests []*MediaBackfillRequest) {
rows, err := mbrq.db.Query(getMediaBackfillRequestsForUser, userID)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
for rows.Next() {
requests = append(requests, mbrq.newMediaBackfillRequest().Scan(rows))
}
return
func (mbr *MediaBackfillRequest) sqlVariables() []any {
return []any{mbr.UserID, mbr.PortalKey.JID, mbr.PortalKey.Receiver, mbr.EventID, mbr.MediaKey, mbr.Status, mbr.Error}
}
func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(userID id.UserID) {
_, err := mbrq.db.Exec("DELETE FROM media_backfill_requests WHERE user_mxid=$1", userID)
if err != nil {
mbrq.log.Warnfln("Failed to delete media backfill requests for %s: %v", userID, err)
}
func (mbr *MediaBackfillRequest) Upsert(ctx context.Context) error {
return mbr.qh.Exec(ctx, upsertMediaBackfillRequestQuery, mbr.sqlVariables()...)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,29 +17,22 @@
package database
import (
"database/sql"
"errors"
"context"
"fmt"
"strings"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type MessageQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*Message]
}
func (mq *MessageQuery) New() *Message {
return &Message{
db: mq.db,
log: mq.log,
}
func newMessage(qh *dbutil.QueryHelper[*Message]) *Message {
return &Message{qh: qh}
}
const (
@ -67,60 +60,47 @@ const (
SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
`
insertMessageQuery = `
INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
markMessageSentQuery = "UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4"
updateMessageMXIDQuery = "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
deleteMessageQuery = "DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3"
)
func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return
func (mq *MessageQuery) GetAll(ctx context.Context, chat PortalKey) ([]*Message, error) {
return mq.QueryMany(ctx, getAllMessagesQuery, chat.JID, chat.Receiver)
}
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
func (mq *MessageQuery) GetByJID(ctx context.Context, chat PortalKey, jid types.MessageID) (*Message, error) {
return mq.QueryOne(ctx, getMessageByJIDQuery, chat.JID, chat.Receiver, jid)
}
func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
}
func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat PortalKey) (*Message, error) {
return mq.GetLastInChatBefore(ctx, chat, time.Now().Add(60*time.Second))
}
func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
if msg == nil || msg.Timestamp.IsZero() {
func (mq *MessageQuery) GetLastInChatBefore(ctx context.Context, chat PortalKey, maxTimestamp time.Time) (*Message, error) {
msg, err := mq.QueryOne(ctx, getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix())
if msg != nil && msg.Timestamp.IsZero() {
// Old db, we don't know what the last message is.
return nil
msg = nil
}
return msg
return msg, err
}
func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
func (mq *MessageQuery) GetFirstInChat(ctx context.Context, chat PortalKey) (*Message, error) {
return mq.QueryOne(ctx, getFirstMessageInChatQuery, chat.JID, chat.Receiver)
}
func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) {
rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return
}
func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
if row == nil {
return nil
}
return mq.New().Scan(row)
func (mq *MessageQuery) GetMessagesBetween(ctx context.Context, chat PortalKey, minTimestamp, maxTimestamp time.Time) ([]*Message, error) {
return mq.QueryMany(ctx, getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
}
type MessageErrorType string
@ -144,8 +124,7 @@ const (
)
type Message struct {
db *Database
log log.Logger
qh *dbutil.QueryHelper[*Message]
Chat PortalKey
JID types.MessageID
@ -172,76 +151,49 @@ func (msg *Message) IsFakeJID() bool {
const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s"
func (msg *Message) Scan(row dbutil.Scannable) *Message {
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
msg.log.Errorln("Database scan failed:", err)
}
return nil
return nil, err
}
if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") {
_, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID)
if err != nil {
msg.log.Errorln("Parsing gallery MXID failed:", err)
return nil, fmt.Errorf("failed to parse gallery MXID: %w", err)
}
}
if ts != 0 {
msg.Timestamp = time.Unix(ts, 0)
}
return msg
return msg, nil
}
func (msg *Message) Insert(txn dbutil.Execable) {
if txn == nil {
txn = msg.db
}
var sender interface{} = msg.Sender
// Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
if msg.Sender.IsEmpty() {
sender = ""
}
func (msg *Message) sqlVariables() []any {
mxid := msg.MXID.String()
if msg.GalleryPart != 0 {
mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid)
}
_, err := txn.Exec(`
INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`, msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}
return []any{msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, msg.Sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID}
}
func (msg *Message) MarkSent(ts time.Time) {
func (msg *Message) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...)
}
func (msg *Message) MarkSent(ctx context.Context, ts time.Time) error {
msg.Sent = true
msg.Timestamp = ts
_, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}
return msg.qh.Exec(ctx, markMessageSentQuery, ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
}
func (msg *Message) UpdateMXID(txn dbutil.Execable, mxid id.EventID, newType MessageType, newError MessageErrorType) {
if txn == nil {
txn = msg.db
}
func (msg *Message) UpdateMXID(ctx context.Context, mxid id.EventID, newType MessageType, newError MessageErrorType) error {
msg.MXID = mxid
msg.Type = newType
msg.Error = newError
_, err := txn.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6",
mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}
return msg.qh.Exec(ctx, updateMessageMXIDQuery, mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
}
func (msg *Message) Delete() {
_, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err)
}
func (msg *Message) Delete(ctx context.Context) error {
return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,28 +17,56 @@
package database
import (
"context"
"fmt"
"strings"
"github.com/lib/pq"
"go.mau.fi/util/dbutil"
)
func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
var hash []byte
err = rows.Scan(&id, &hash)
if err != nil {
// return below
} else if len(hash) != 32 {
err = fmt.Errorf("unexpected hash length %d", len(hash))
} else {
hashArr = *(*[32]byte)(hash)
const (
bulkPutPollOptionsQuery = "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
bulkPutPollOptionsQueryTemplate = "($1, $%d, $%d)"
bulkPutPollOptionsQueryPlaceholder = "($1, $2, $3)"
getPollOptionIDsByHashesQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
getPollOptionHashesByIDsQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
getPollOptionQuerySQLiteArrayTemplate = " IN (%s)"
getPollOptionQueryArrayPlaceholder = " = ANY($2)"
)
func init() {
if strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, "meow") == bulkPutPollOptionsQuery {
panic("Bulk insert query placeholder not found")
}
if strings.ReplaceAll(getPollOptionIDsByHashesQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
panic("Array select query placeholder not found")
}
if strings.ReplaceAll(getPollOptionHashesByIDsQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
panic("Array select query placeholder not found")
}
return
}
func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
type pollOption struct {
id string
hash [32]byte
}
func scanPollOption(rows dbutil.Scannable) (*pollOption, error) {
var hash []byte
var id string
err := rows.Scan(&id, &hash)
if err != nil {
return nil, err
} else if len(hash) != 32 {
return nil, fmt.Errorf("unexpected hash length %d", len(hash))
} else {
return &pollOption{id: id, hash: [32]byte(hash)}, nil
}
}
func (msg *Message) PutPollOptions(ctx context.Context, opts map[[32]byte]string) error {
args := make([]any, len(opts)*2+1)
placeholders := make([]string, len(opts))
args[0] = msg.MXID
@ -47,72 +75,47 @@ func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
args[i*2+1] = id
hashCopy := hash
args[i*2+2] = hashCopy[:]
placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
placeholders[i] = fmt.Sprintf(bulkPutPollOptionsQueryTemplate, i*2+2, i*2+3)
i++
}
query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
_, err := msg.db.Exec(query, args...)
if err != nil {
msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
}
query := strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, strings.Join(placeholders, ","))
return msg.qh.Exec(ctx, query, args...)
}
func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
func getPollOptions[LookupKey any, Key comparable, Value any](
ctx context.Context,
msg *Message,
query string,
things []LookupKey,
getKeyValue func(option *pollOption) (Key, Value),
) (map[Key]Value, error) {
var args []any
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(hashes)}
if msg.qh.GetDB().Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(things)}
} else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
args = make([]any, len(hashes)+1)
query = strings.ReplaceAll(query, getPollOptionQueryArrayPlaceholder, fmt.Sprintf(getPollOptionQuerySQLiteArrayTemplate, strings.TrimSuffix(strings.Repeat("?,", len(things)), ",")))
args = make([]any, len(things)+1)
args[0] = msg.MXID
for i, hash := range hashes {
args[i+1] = hash
for i, thing := range things {
args[i+1] = thing
}
}
ids := make(map[[32]byte]string, len(hashes))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
break
}
ids[hash] = id
}
}
return ids
return dbutil.RowIterAsMap(
dbutil.ConvertRowFn[*pollOption](scanPollOption).NewRowIter(msg.qh.GetDB().Query(ctx, query, args...)),
getKeyValue,
)
}
func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
var args []any
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(ids)}
} else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
args = make([]any, len(ids)+1)
args[0] = msg.MXID
for i, id := range ids {
args[i+1] = id
}
}
hashes := make(map[string][32]byte, len(ids))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
break
}
hashes[id] = hash
}
}
return hashes
func (msg *Message) GetPollOptionIDs(ctx context.Context, hashes [][]byte) (map[[32]byte]string, error) {
return getPollOptions(
ctx, msg, getPollOptionIDsByHashesQuery, hashes,
func(t *pollOption) ([32]byte, string) { return t.hash, t.id },
)
}
func (msg *Message) GetPollOptionHashes(ctx context.Context, ids []string) (map[string][32]byte, error) {
return getPollOptions(
ctx, msg, getPollOptionHashesByIDsQuery, ids,
func(t *pollOption) (string, [32]byte) { return t.id, t.hash },
)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,15 +17,14 @@
package database
import (
"context"
"database/sql"
"fmt"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
)
type PortalKey struct {
@ -53,90 +52,89 @@ func (key PortalKey) String() string {
}
type PortalQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*Portal]
}
func (pq *PortalQuery) New() *Portal {
func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
return &Portal{
db: pq.db,
log: pq.log,
qh: qh,
}
}
const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, relay_user_id, expiration_time"
func (pq *PortalQuery) GetAll() []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns))
}
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid)
}
func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD())
}
func (pq *PortalQuery) GetAllByParentGroup(jid types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE parent_group=$1", portalColumns), jid)
}
func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns), receiver.ToNonAD())
}
func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
receiver = receiver.ToNonAD()
rows, err := pq.db.Query(`
const (
getAllPortalsQuery = `
SELECT jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space,
first_event_id, next_batch_id, relay_user_id, expiration_time
FROM portal
`
getPortalByJIDQuery = getAllPortalsQuery + " WHERE jid=$1 AND receiver=$2"
getPortalByMXIDQuery = getAllPortalsQuery + " WHERE mxid=$1"
getPrivateChatsWithQuery = getAllPortalsQuery + " WHERE jid=$1"
getPrivateChatsOfQuery = getAllPortalsQuery + " WHERE receiver=$1"
getAllPortalsByParentGroupQuery = getAllPortalsQuery + " WHERE parent_group=$1"
findPrivateChatPortalsNotInSpaceQuery = `
SELECT jid FROM portal
LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
WHERE mxid<>'' AND receiver=$1 AND (user_portal.in_space=false OR user_portal.in_space IS NULL)
`, receiver)
if err != nil {
pq.log.Errorfln("Failed to find private chats not in space for %s: %v", receiver, err)
return
} else if rows == nil {
return
}
for rows.Next() {
var key PortalKey
`
insertPortalQuery = `
INSERT INTO portal (
jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space,
first_event_id, next_batch_id, relay_user_id, expiration_time
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
`
updatePortalQuery = `
UPDATE portal
SET mxid=$3, name=$4, name_set=$5, topic=$6, topic_set=$7, avatar=$8, avatar_url=$9, avatar_set=$10,
encrypted=$11, last_sync=$12, is_parent=$13, parent_group=$14, in_space=$15,
first_event_id=$16, next_batch_id=$17, relay_user_id=$18, expiration_time=$19
WHERE jid=$1 AND receiver=$2
`
clearPortalInSpaceQuery = "UPDATE portal SET in_space=false WHERE parent_group=$1"
deletePortalQuery = "DELETE FROM portal WHERE jid=$1 AND receiver=$2"
)
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsQuery)
}
func (pq *PortalQuery) GetByJID(ctx context.Context, key PortalKey) (*Portal, error) {
return pq.QueryOne(ctx, getPortalByJIDQuery, key.JID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid)
}
func (pq *PortalQuery) GetAllByJID(ctx context.Context, jid types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getPrivateChatsWithQuery, jid.ToNonAD())
}
func (pq *PortalQuery) FindPrivateChats(ctx context.Context, receiver types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getPrivateChatsOfQuery, receiver.ToNonAD())
}
func (pq *PortalQuery) GetAllByParentGroup(ctx context.Context, jid types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsByParentGroupQuery, jid)
}
func (pq *PortalQuery) FindPrivateChatsNotInSpace(ctx context.Context, receiver types.JID) (keys []PortalKey, err error) {
receiver = receiver.ToNonAD()
scanFn := func(rows dbutil.Scannable) (key PortalKey, err error) {
key.Receiver = receiver
err = rows.Scan(&key.JID)
if err == nil {
keys = append(keys, key)
}
return
}
return
}
func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
portals = append(portals, pq.New().Scan(rows))
}
return
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
return dbutil.ConvertRowFn[PortalKey](scanFn).
NewRowIter(pq.GetDB().Query(ctx, findPrivateChatPortalsNotInSpaceQuery, receiver)).
AsList()
}
type Portal struct {
db *Database
log log.Logger
qh *dbutil.QueryHelper[*Portal]
Key PortalKey
MXID id.RoomID
@ -161,15 +159,17 @@ type Portal struct {
ExpirationTime uint32
}
func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString
var lastSyncTs int64
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
err := row.Scan(
&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet,
&portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted,
&lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace,
&firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime,
)
if err != nil {
if err != sql.ErrNoRows {
portal.log.Errorln("Database scan failed:", err)
}
return nil
return nil, err
}
if lastSyncTs > 0 {
portal.LastSync = time.Unix(lastSyncTs, 0)
@ -182,96 +182,36 @@ func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
portal.FirstEventID = id.EventID(firstEventID.String)
portal.NextBatchID = id.BatchID(nextBatchID.String)
portal.RelayUserID = id.UserID(relayUserID.String)
return portal
return portal, nil
}
func (portal *Portal) mxidPtr() *id.RoomID {
if len(portal.MXID) > 0 {
return &portal.MXID
func (portal *Portal) sqlVariables() []any {
var lastSyncTS int64
if !portal.LastSync.IsZero() {
lastSyncTS = portal.LastSync.Unix()
}
return nil
}
func (portal *Portal) relayUserPtr() *id.UserID {
if len(portal.RelayUserID) > 0 {
return &portal.RelayUserID
}
return nil
}
func (portal *Portal) parentGroupPtr() *string {
if !portal.ParentGroup.IsEmpty() {
val := portal.ParentGroup.String()
return &val
}
return nil
}
func (portal *Portal) lastSyncTs() int64 {
if portal.LastSync.IsZero() {
return 0
}
return portal.LastSync.Unix()
}
func (portal *Portal) Insert() {
_, err := portal.db.Exec(`
INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id,
relay_user_id, expiration_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
`,
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet,
portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(),
portal.IsParent, portal.parentGroupPtr(), portal.InSpace, portal.FirstEventID.String(), portal.NextBatchID.String(),
portal.relayUserPtr(), portal.ExpirationTime)
if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
return []any{
portal.Key.JID, portal.Key.Receiver, dbutil.StrPtr(portal.MXID), portal.Name, portal.NameSet,
portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted,
lastSyncTS, portal.IsParent, dbutil.StrPtr(portal.ParentGroup.String()), portal.InSpace,
portal.FirstEventID.String(), portal.NextBatchID.String(), dbutil.StrPtr(portal.RelayUserID), portal.ExpirationTime,
}
}
func (portal *Portal) Update(txn dbutil.Execable) {
if txn == nil {
txn = portal.db
}
_, err := txn.Exec(`
UPDATE portal
SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13,
first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17
WHERE jid=$18 AND receiver=$19
`, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(),
portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), portal.IsParent, portal.parentGroupPtr(), portal.InSpace,
portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime,
portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
}
func (portal *Portal) Insert(ctx context.Context) error {
return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...)
}
func (portal *Portal) Delete() {
txn, err := portal.db.Begin()
if err != nil {
portal.log.Errorfln("Failed to begin transaction to delete portal %v: %v", portal.Key, err)
return
}
defer func() {
func (portal *Portal) Update(ctx context.Context) error {
return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...)
}
func (portal *Portal) Delete(ctx context.Context) error {
return portal.qh.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
err := portal.qh.Exec(ctx, clearPortalInSpaceQuery, portal.Key.JID)
if err != nil {
err = txn.Rollback()
if err != nil {
portal.log.Warnfln("Failed to rollback failed portal delete transaction: %v", err)
}
} else if err = txn.Commit(); err != nil {
portal.log.Warnfln("Failed to commit portal delete transaction: %v", err)
return err
}
}()
_, err = txn.Exec("UPDATE portal SET in_space=false WHERE parent_group=$1", portal.Key.JID)
if err != nil {
portal.log.Warnfln("Failed to mark child groups of %v as not in space: %v", portal.Key.JID, err)
return
}
_, err = txn.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to delete %v: %v", portal.Key, err)
}
return portal.qh.Exec(ctx, deletePortalQuery, portal.Key.JID, portal.Key.Receiver)
})
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,74 +17,70 @@
package database
import (
"context"
"database/sql"
"time"
"go.mau.fi/util/dbutil"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
)
type PuppetQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*Puppet]
}
func (pq *PuppetQuery) New() *Puppet {
func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
return &Puppet{
db: pq.db,
log: pq.log,
qh: qh,
EnablePresence: true,
EnableReceipts: true,
}
}
func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet")
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
return
const (
getAllPuppetsQuery = `
SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts
FROM puppet
`
getPuppetByJIDQuery = getAllPuppetsQuery + " WHERE username=$1"
getPuppetByCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid=$1"
getAllPuppetsWithCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid<>''"
insertPuppetQuery = `
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`
updatePuppetQuery = `
UPDATE puppet
SET avatar=$2, avatar_url=$3, avatar_set=$4, displayname=$5, name_quality=$6, name_set=$7, contact_info_set=$8,
last_sync=$9, custom_mxid=$10, access_token=$11, next_batch=$12, enable_presence=$13, enable_receipts=$14
WHERE username=$1
`
)
func (pq *PuppetQuery) GetAll(ctx context.Context) ([]*Puppet, error) {
return pq.QueryMany(ctx, getAllPuppetsQuery)
}
func (pq *PuppetQuery) Get(jid types.JID) *Puppet {
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE username=$1", jid.User)
if row == nil {
return nil
}
return pq.New().Scan(row)
func (pq *PuppetQuery) Get(ctx context.Context, jid types.JID) (*Puppet, error) {
return pq.QueryOne(ctx, getPuppetByJIDQuery, jid.User)
}
func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid=$1", mxid)
if row == nil {
return nil
}
return pq.New().Scan(row)
func (pq *PuppetQuery) GetByCustomMXID(ctx context.Context, mxid id.UserID) (*Puppet, error) {
return pq.QueryOne(ctx, getPuppetByCustomMXIDQuery, mxid)
}
func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid<>''")
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
return
func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, error) {
return pq.QueryMany(ctx, getAllPuppetsWithCustomMXIDQuery)
}
type Puppet struct {
db *Database
log log.Logger
qh *dbutil.QueryHelper[*Puppet]
JID types.JID
Avatar string
@ -103,17 +99,14 @@ type Puppet struct {
EnableReceipts bool
}
func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
var quality, lastSync sql.NullInt64
var enablePresence, enableReceipts, nameSet, avatarSet, contactInfoSet sql.NullBool
var username string
err := row.Scan(&username, &avatar, &avatarURL, &displayname, &quality, &nameSet, &avatarSet, &contactInfoSet, &lastSync, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts)
if err != nil {
if err != sql.ErrNoRows {
puppet.log.Errorln("Database scan failed:", err)
}
return nil
return nil, err
}
puppet.JID = types.NewJID(username, types.DefaultUserServer)
puppet.Displayname = displayname.String
@ -131,45 +124,30 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
puppet.NextBatch = nextBatch.String
puppet.EnablePresence = enablePresence.Bool
puppet.EnableReceipts = enableReceipts.Bool
return puppet
return puppet, nil
}
func (puppet *Puppet) Insert() {
if puppet.JID.Server != types.DefaultUserServer {
puppet.log.Warnfln("Not inserting %s: not a user", puppet.JID)
return
}
var lastSyncTs int64
func (puppet *Puppet) sqlVariables() []any {
var lastSyncTS int64
if !puppet.LastSync.IsZero() {
lastSyncTs = puppet.LastSync.Unix()
lastSyncTS = puppet.LastSync.Unix()
}
_, err := puppet.db.Exec(`
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`, puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
return []any{
puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTS,
puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
puppet.EnablePresence, puppet.EnableReceipts,
)
if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
}
}
func (puppet *Puppet) Update() {
var lastSyncTs int64
if !puppet.LastSync.IsZero() {
lastSyncTs = puppet.LastSync.Unix()
}
_, err := puppet.db.Exec(`
UPDATE puppet
SET displayname=$1, name_quality=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, contact_info_set=$7, last_sync=$8,
custom_mxid=$9, access_token=$10, next_batch=$11, enable_presence=$12, enable_receipts=$13
WHERE username=$14
`, puppet.Displayname, puppet.NameQuality, puppet.NameSet, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.ContactInfoSet,
lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts,
puppet.JID.User)
if err != nil {
puppet.log.Warnfln("Failed to update %s: %v", puppet.JID, err)
func (puppet *Puppet) Insert(ctx context.Context) error {
if puppet.JID.Server != types.DefaultUserServer {
zerolog.Ctx(ctx).Warn().Stringer("jid", puppet.JID).Msg("Not inserting puppet: not a user")
return nil
}
return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...)
}
func (puppet *Puppet) Update(ctx context.Context) error {
return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,26 +17,20 @@
package database
import (
"database/sql"
"errors"
"context"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type ReactionQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*Reaction]
}
func (rq *ReactionQuery) New() *Reaction {
return &Reaction{
db: rq.db,
log: rq.log,
}
func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
return &Reaction{qh: qh}
}
const (
@ -55,28 +49,20 @@ const (
DO UPDATE SET mxid=excluded.mxid, jid=excluded.jid
`
deleteReactionQuery = `
DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4 AND mxid=$5
DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4
`
)
func (rq *ReactionQuery) GetByTargetJID(chat PortalKey, jid types.MessageID, sender types.JID) *Reaction {
return rq.maybeScan(rq.db.QueryRow(getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD()))
func (rq *ReactionQuery) GetByTargetJID(ctx context.Context, chat PortalKey, jid types.MessageID, sender types.JID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD())
}
func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction {
return rq.maybeScan(rq.db.QueryRow(getReactionByMXIDQuery, mxid))
}
func (rq *ReactionQuery) maybeScan(row *sql.Row) *Reaction {
if row == nil {
return nil
}
return rq.New().Scan(row)
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
}
type Reaction struct {
db *Database
log log.Logger
qh *dbutil.QueryHelper[*Reaction]
Chat PortalKey
TargetJID types.MessageID
@ -85,35 +71,19 @@ type Reaction struct {
JID types.MessageID
}
func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction {
err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
reaction.log.Errorln("Database scan failed:", err)
}
return nil
}
return reaction
func (reaction *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
return dbutil.ValueOrErr(reaction, row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID))
}
func (reaction *Reaction) Upsert(txn dbutil.Execable) {
func (reaction *Reaction) sqlVariables() []any {
reaction.Sender = reaction.Sender.ToNonAD()
if txn == nil {
txn = reaction.db
}
_, err := txn.Exec(upsertReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID)
if err != nil {
reaction.log.Warnfln("Failed to upsert reaction to %s@%s by %s: %v", reaction.Chat, reaction.TargetJID, reaction.Sender, err)
}
return []any{reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID}
}
func (reaction *Reaction) GetTarget() *Message {
return reaction.db.Message.GetByJID(reaction.Chat, reaction.TargetJID)
func (reaction *Reaction) Upsert(ctx context.Context) error {
return reaction.qh.Exec(ctx, upsertReactionQuery, reaction.sqlVariables()...)
}
func (reaction *Reaction) Delete() {
_, err := reaction.db.Exec(deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID)
if err != nil {
reaction.log.Warnfln("Failed to delete reaction %s: %v", reaction.MXID, err)
}
func (reaction *Reaction) Delete(ctx context.Context) error {
return reaction.qh.Exec(ctx, deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender)
}

View File

@ -17,6 +17,7 @@
package upgrades
import (
"context"
"embed"
"errors"
@ -29,7 +30,7 @@ var Table dbutil.UpgradeTable
var rawUpgrades embed.FS
func init() {
Table.Register(-1, 35, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
Table.Register(-1, 35, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
})
Table.RegisterFS(rawUpgrades)

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,63 +17,65 @@
package database
import (
"context"
"database/sql"
"sync"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
)
type UserQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*User]
}
func (uq *UserQuery) New() *User {
func newUser(qh *dbutil.QueryHelper[*User]) *User {
return &User{
db: uq.db,
log: uq.log,
qh: qh,
lastReadCache: make(map[PortalKey]time.Time),
inSpaceCache: make(map[PortalKey]bool),
}
}
func (uq *UserQuery) GetAll() (users []*User) {
rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
users = append(users, uq.New().Scan(rows))
}
return
const (
getAllUsersQuery = `SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`
getUserByMXIDQuery = getAllUsersQuery + ` WHERE mxid=$1`
getUserByUsernameQuery = getAllUsersQuery + ` WHERE username=$1`
insertUserQuery = `
INSERT INTO "user" (
mxid, username, agent, device,
management_room, space_room,
phone_last_seen, phone_last_pinged, timezone
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
updateUserQuery = `
UPDATE "user"
SET username=$1, agent=$2, device=$3,
management_room=$4, space_room=$5,
phone_last_seen=$6, phone_last_pinged=$7, timezone=$8
WHERE mxid=$9
`
getUserLastAppStateKeyIDQuery = "SELECT key_id FROM whatsmeow_app_state_sync_keys WHERE jid=$1 ORDER BY timestamp DESC LIMIT 1"
)
func (uq *UserQuery) GetAll(ctx context.Context) ([]*User, error) {
return uq.QueryMany(ctx, getAllUsersQuery)
}
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE mxid=$1`, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
return uq.QueryOne(ctx, getUserByMXIDQuery, userID)
}
func (uq *UserQuery) GetByUsername(username string) *User {
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE username=$1`, username)
if row == nil {
return nil
}
return uq.New().Scan(row)
func (uq *UserQuery) GetByUsername(ctx context.Context, username string) (*User, error) {
return uq.QueryOne(ctx, getUserByUsernameQuery, username)
}
type User struct {
db *Database
log log.Logger
qh *dbutil.QueryHelper[*User]
MXID id.UserID
JID types.JID
@ -89,20 +91,21 @@ type User struct {
inSpaceCacheLock sync.Mutex
}
func (user *User) Scan(row dbutil.Scannable) *User {
func (user *User) Scan(row dbutil.Scannable) (*User, error) {
var username, timezone sql.NullString
var device, agent sql.NullByte
var device, agent sql.NullInt16
var phoneLastSeen, phoneLastPinged sql.NullInt64
err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged, &timezone)
if err != nil {
if err != sql.ErrNoRows {
user.log.Errorln("Database scan failed:", err)
}
return nil
return nil, err
}
user.Timezone = timezone.String
if len(username.String) > 0 {
user.JID = types.NewADJID(username.String, agent.Byte, device.Byte)
user.JID = types.JID{
User: username.String,
Device: uint16(device.Int16),
Server: types.DefaultUserServer,
}
}
if phoneLastSeen.Valid {
user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0)
@ -110,66 +113,34 @@ func (user *User) Scan(row dbutil.Scannable) *User {
if phoneLastPinged.Valid {
user.PhoneLastPinged = time.Unix(phoneLastPinged.Int64, 0)
}
return user
return user, nil
}
func (user *User) usernamePtr() *string {
func (user *User) sqlVariables() []any {
var username *string
var agent, device *uint16
if !user.JID.IsEmpty() {
return &user.JID.User
username = dbutil.StrPtr(user.JID.User)
var zero uint16
agent = &zero
device = dbutil.NumPtr(user.JID.Device)
}
return nil
}
func (user *User) agentPtr() *uint8 {
if !user.JID.IsEmpty() {
zero := uint8(0)
return &zero
}
return nil
}
func (user *User) devicePtr() *uint8 {
if !user.JID.IsEmpty() {
device := uint8(user.JID.Device)
return &device
}
return nil
}
func (user *User) phoneLastSeenPtr() *int64 {
if user.PhoneLastSeen.IsZero() {
return nil
}
ts := user.PhoneLastSeen.Unix()
return &ts
}
func (user *User) phoneLastPingedPtr() *int64 {
if user.PhoneLastPinged.IsZero() {
return nil
}
ts := user.PhoneLastPinged.Unix()
return &ts
}
func (user *User) Insert() {
_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone)
if err != nil {
user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
return []any{
username, agent, device, user.ManagementRoom, user.SpaceRoom,
dbutil.UnixPtr(user.PhoneLastSeen), dbutil.UnixPtr(user.PhoneLastPinged),
user.Timezone, user.MXID,
}
}
func (user *User) Update() {
_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7, timezone=$8 WHERE mxid=$9`,
user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone, user.MXID)
if err != nil {
user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
}
func (user *User) Insert(ctx context.Context) error {
return user.qh.Exec(ctx, insertUserQuery, user.sqlVariables()...)
}
func (user *User) GetLastAppStateKeyID() ([]byte, error) {
var keyID []byte
err := user.db.QueryRow("SELECT key_id FROM whatsmeow_app_state_sync_keys ORDER BY timestamp DESC LIMIT 1").Scan(&keyID)
return keyID, err
func (user *User) Update(ctx context.Context) error {
return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...)
}
func (user *User) GetLastAppStateKeyID(ctx context.Context) (keyID []byte, err error) {
err = user.qh.GetDB().QueryRow(ctx, getUserLastAppStateKeyIDQuery, user.JID).Scan(&keyID)
return
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,69 +17,97 @@
package database
import (
"context"
"database/sql"
"errors"
"time"
"github.com/rs/zerolog"
)
func (user *User) GetLastReadTS(portal PortalKey) time.Time {
const (
getLastReadTSQuery = "SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
setLastReadTSQuery = `
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
`
getIsInSpaceQuery = "SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
setIsInSpaceQuery = `
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
`
)
func (user *User) GetLastReadTS(ctx context.Context, portal PortalKey) time.Time {
user.lastReadCacheLock.Lock()
defer user.lastReadCacheLock.Unlock()
if cached, ok := user.lastReadCache[portal]; ok {
return cached
}
var ts int64
err := user.db.QueryRow("SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&ts)
var parsedTS time.Time
err := user.qh.GetDB().QueryRow(ctx, getLastReadTSQuery, user.MXID, portal.JID, portal.Receiver).Scan(&ts)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
user.log.Warnfln("Failed to scan last read timestamp from user portal table: %v", err)
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to query last read timestamp")
return parsedTS
}
if ts == 0 {
user.lastReadCache[portal] = time.Time{}
} else {
user.lastReadCache[portal] = time.Unix(ts, 0)
if ts != 0 {
parsedTS = time.Unix(ts, 0)
}
user.lastReadCache[portal] = parsedTS
return user.lastReadCache[portal]
}
func (user *User) SetLastReadTS(portal PortalKey, ts time.Time) {
func (user *User) SetLastReadTS(ctx context.Context, portal PortalKey, ts time.Time) {
user.lastReadCacheLock.Lock()
defer user.lastReadCacheLock.Unlock()
_, err := user.db.Exec(`
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
`, user.MXID, portal.JID, portal.Receiver, ts.Unix())
_, err := user.qh.GetDB().Exec(ctx, setLastReadTSQuery, user.MXID, portal.JID, portal.Receiver, ts.Unix())
if err != nil {
user.log.Warnfln("Failed to update last read timestamp: %v", err)
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to update last read timestamp")
} else {
user.log.Debugfln("Set last read timestamp of %s in %s to %d", user.MXID, portal.String(), ts.Unix())
zerolog.Ctx(ctx).Debug().
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Time("last_read_ts", ts).
Msg("Updated last read timestamp of portal")
user.lastReadCache[portal] = ts
}
}
func (user *User) IsInSpace(portal PortalKey) bool {
func (user *User) IsInSpace(ctx context.Context, portal PortalKey) bool {
user.inSpaceCacheLock.Lock()
defer user.inSpaceCacheLock.Unlock()
if cached, ok := user.inSpaceCache[portal]; ok {
return cached
}
var inSpace bool
err := user.db.QueryRow("SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&inSpace)
err := user.qh.GetDB().QueryRow(ctx, getIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver).Scan(&inSpace)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
user.log.Warnfln("Failed to scan in space status from user portal table: %v", err)
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to query in space status")
return false
}
user.inSpaceCache[portal] = inSpace
return inSpace
}
func (user *User) MarkInSpace(portal PortalKey) {
func (user *User) MarkInSpace(ctx context.Context, portal PortalKey) {
user.inSpaceCacheLock.Lock()
defer user.inSpaceCacheLock.Unlock()
_, err := user.db.Exec(`
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
`, user.MXID, portal.JID, portal.Receiver)
_, err := user.qh.GetDB().Exec(ctx, setIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver)
if err != nil {
user.log.Warnfln("Failed to update in space status: %v", err)
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to update in space status")
} else {
user.inSpaceCache[portal] = true
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,10 +17,11 @@
package main
import (
"context"
"fmt"
"time"
"go.mau.fi/util/dbutil"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
@ -28,47 +29,74 @@ import (
"maunium.net/go/mautrix-whatsapp/database"
)
func (portal *Portal) MarkDisappearing(txn dbutil.Execable, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) {
func (portal *Portal) MarkDisappearing(ctx context.Context, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) {
if expiresIn == 0 {
return
}
expiresAt := startsAt.Add(expiresIn)
msg := portal.bridge.DB.DisappearingMessage.NewWithValues(portal.MXID, eventID, expiresIn, expiresAt)
msg.Insert(txn)
err := msg.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to insert disappearing message")
}
if expiresAt.Before(time.Now().Add(1 * time.Hour)) {
go portal.sleepAndDelete(msg)
go portal.sleepAndDelete(context.WithoutCancel(ctx), msg)
}
}
func (br *WABridge) SleepAndDeleteUpcoming() {
for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
func (br *WABridge) SleepAndDeleteUpcoming(ctx context.Context) {
msgs, err := br.DB.DisappearingMessage.GetUpcomingScheduled(ctx, 1*time.Hour)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get upcoming disappearing messages")
return
}
for _, msg := range msgs {
portal := br.GetPortalByMXID(msg.RoomID)
if portal == nil {
msg.Delete()
err = msg.Delete(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("event_id", msg.EventID).
Msg("Failed to delete disappearing message row with no portal")
}
} else {
go portal.sleepAndDelete(msg)
go portal.sleepAndDelete(ctx, msg)
}
}
}
func (portal *Portal) sleepAndDelete(msg *database.DisappearingMessage) {
func (portal *Portal) sleepAndDelete(ctx context.Context, msg *database.DisappearingMessage) {
if _, alreadySleeping := portal.currentlySleepingToDelete.LoadOrStore(msg.EventID, true); alreadySleeping {
return
}
defer portal.currentlySleepingToDelete.Delete(msg.EventID)
log := zerolog.Ctx(ctx)
sleepTime := msg.ExpireAt.Sub(time.Now())
portal.log.Debugfln("Sleeping for %s to make %s disappear", sleepTime, msg.EventID)
log.Debug().
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Dur("sleep_time", sleepTime).
Msg("Sleeping before making message disappear")
time.Sleep(sleepTime)
_, err := portal.MainIntent().RedactEvent(msg.RoomID, msg.EventID, mautrix.ReqRedact{
_, err := portal.MainIntent().RedactEvent(ctx, msg.RoomID, msg.EventID, mautrix.ReqRedact{
Reason: "Message expired",
TxnID: fmt.Sprintf("mxwa_disappear_%s", msg.EventID),
})
if err != nil {
portal.log.Warnfln("Failed to make %s disappear: %v", msg.EventID, err)
log.Err(err).
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Msg("Failed to make event disappear")
} else {
portal.log.Debugfln("Disappeared %s", msg.EventID)
log.Debug().
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Msg("Disappeared event")
}
err = msg.Delete(ctx)
if err != nil {
log.Err(err).Msg("Failed to delete disapperaing message row in database after redacting event")
}
msg.Delete()
}

View File

@ -17,12 +17,14 @@
package main
import (
"context"
"fmt"
"html"
"regexp"
"sort"
"strings"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/event"
@ -104,22 +106,27 @@ func NewFormatter(bridge *WABridge) *Formatter {
return formatter
}
func (formatter *Formatter) getMatrixInfoByJID(roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) {
func (formatter *Formatter) getMatrixInfoByJID(ctx context.Context, roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) {
if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
mxid = puppet.MXID
displayname = puppet.Displayname
}
if user := formatter.bridge.GetUserByJID(jid); user != nil {
mxid = user.MXID
member := formatter.bridge.StateStore.GetMember(roomID, user.MXID)
if len(member.Displayname) > 0 {
member, err := formatter.bridge.StateStore.GetMember(ctx, roomID, user.MXID)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("room_id", roomID).
Stringer("user_id", user.MXID).
Msg("Failed to get member profile from state store")
} else if len(member.Displayname) > 0 {
displayname = member.Displayname
}
}
return
}
func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) {
func (formatter *Formatter) ParseWhatsApp(ctx context.Context, roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) {
output := html.EscapeString(content.Body)
for regex, replacement := range formatter.waReplString {
output = regex.ReplaceAllString(output, replacement)
@ -145,7 +152,7 @@ func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.Messa
// TODO lid support?
continue
}
mxid, displayname := formatter.getMatrixInfoByJID(roomID, jid)
mxid, displayname := formatter.getMatrixInfoByJID(ctx, roomID, jid)
number := "@" + jid.User
output = strings.ReplaceAll(output, number, fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname))
content.Body = strings.ReplaceAll(content.Body, number, displayname)

45
go.mod
View File

@ -1,25 +1,25 @@
module maunium.net/go/mautrix-whatsapp
go 1.20
go 1.21
require (
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.19
github.com/prometheus/client_golang v1.17.0
github.com/rs/zerolog v1.31.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/prometheus/client_golang v1.19.0
github.com/rs/zerolog v1.32.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/tidwall/gjson v1.17.0
go.mau.fi/util v0.2.1
github.com/tidwall/gjson v1.17.1
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e
go.mau.fi/webp v0.1.0
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611
golang.org/x/image v0.14.0
golang.org/x/net v0.19.0
google.golang.org/protobuf v1.31.0
maunium.net/go/maulogger/v2 v2.4.1
maunium.net/go/mautrix v0.16.2
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
golang.org/x/image v0.15.0
golang.org/x/net v0.22.0
google.golang.org/protobuf v1.33.0
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa
)
require (
@ -27,24 +27,27 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/yuin/goldmark v1.6.0 // indirect
github.com/yuin/goldmark v1.7.0 // indirect
go.mau.fi/libsignal v0.1.0 // indirect
go.mau.fi/zeroconfig v0.1.2 // indirect
golang.org/x/crypto v0.16.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
maunium.net/go/mauflag v1.0.0 // indirect
)
//replace go.mau.fi/util => ../../Go/go-util
//replace maunium.net/go/mautrix => ../mautrix-go

100
go.sum
View File

@ -1,6 +1,9 @@
filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek=
filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c h1:WqjRVgUO039eiISCjsZC4F9onOEV93DJAk6v33rsZzY=
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c/go.mod h1:b9FFm9y4mEm36G8ytVmS1vkNzJa0KepmcdVY+qf7qRU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
@ -9,18 +12,18 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@ -30,78 +33,75 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU=
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/libsignal v0.1.0 h1:vAKI/nJ5tMhdzke4cTK1fb0idJzz1JuEIpmjprueC+c=
go.mau.fi/libsignal v0.1.0/go.mod h1:R8ovrTezxtUNzCQE5PH30StOQWWeBskBsWE55vMfY9I=
go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw=
go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c=
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e h1:e1jDj/MjleSS5r9DMRbuCZYKy5Rr+sbsu8eWjtLqrGk=
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e/go.mod h1:jOAREC/go8T6rGic01cu6WRa90xi9U4z3QmDjRf8xpo=
go.mau.fi/webp v0.1.0 h1:BHObH/DcFntT9KYun5pDr0Ot4eUZO8k2C7eP7vF4ueA=
go.mau.fi/webp v0.1.0/go.mod h1:e42Z+VMFrUMS9cpEwGRIor+lQWO8oUAyPyMtcL+NMt8=
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735 h1:+teJYCOK6M4Kn2TYCj29levhHVwnJTmgCtEXLtgwQtM=
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735/go.mod h1:5xTtHNaZpGni6z6aE1iEopjW7wNgsKcolZxZrOujK9M=
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462 h1:QOGjCIh2WEfkgX/38KLjnNof79GWx0T+KLrhTHiws3s=
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462/go.mod h1:lQHbhaG/fI+6hfGqz5Vzn2OBJBEZ05H0kCP6iJXriN4=
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 h1:qCEDpW1G+vcj3Y7Fy52pEM1AWm3abj8WimGYejI3SC4=
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg=
maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4=
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa h1:TLSWIAWKIWxLghgzWfp7o92pVCcFR3yLsArc0s/tsMs=
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa/go.mod h1:0sfLB2ejW+lhgio4UlZMmn5i9SuZ8mxFkonFSamrfTE=

View File

@ -17,6 +17,7 @@
package main
import (
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
@ -24,11 +25,11 @@ import (
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/variationselector"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/util/variationselector"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
@ -64,9 +65,8 @@ func (user *User) handleHistorySyncsLoop() {
if batchSend {
// Start the backfill queue.
user.BackfillQueue = &BackfillQueue{
BackfillQuery: user.bridge.DB.Backfill,
BackfillQuery: user.bridge.DB.BackfillQueue,
reCheckChannels: []chan bool{},
log: user.log.Sub("BackfillQueue"),
}
forwardAndImmediate := []database.BackfillType{database.BackfillImmediate, database.BackfillForward}
@ -109,80 +109,113 @@ func (user *User) handleHistorySyncsLoop() {
const EnqueueBackfillsDelay = 30 * time.Second
func (user *User) enqueueAllBackfills() {
nMostRecent := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations)
if len(nMostRecent) > 0 {
user.log.Infofln("%v has passed since the last history sync blob, enqueueing backfills for %d chats", EnqueueBackfillsDelay, len(nMostRecent))
// Find the portals for all the conversations.
portals := []*Portal{}
for _, conv := range nMostRecent {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
user.log.Warnfln("Failed to parse chat JID '%s' in history sync: %v", conv.ConversationID, err)
continue
}
portals = append(portals, user.GetPortalByJID(jid))
}
user.EnqueueImmediateBackfills(portals)
user.EnqueueForwardBackfills(portals)
user.EnqueueDeferredBackfills(portals)
// Tell the queue to check for new backfill requests.
user.BackfillQueue.ReCheck()
log := user.zlog.With().
Str("method", "User.enqueueAllBackfills").
Logger()
ctx := log.WithContext(context.TODO())
nMostRecent, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations)
if err != nil {
log.Err(err).Msg("Failed to get recent history sync conversations from database")
return
} else if len(nMostRecent) == 0 {
return
}
log.Info().
Int("chat_count", len(nMostRecent)).
Msg("Enqueueing backfills for recent chats in history sync")
// Find the portals for all the conversations.
portals := make([]*Portal, 0, len(nMostRecent))
for _, conv := range nMostRecent {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
log.Err(err).Str("conversation_id", conv.ConversationID).Msg("Failed to parse chat JID in history sync")
continue
}
portals = append(portals, user.GetPortalByJID(jid))
}
user.EnqueueImmediateBackfills(ctx, portals)
user.EnqueueForwardBackfills(ctx, portals)
user.EnqueueDeferredBackfills(ctx, portals)
// Tell the queue to check for new backfill requests.
user.BackfillQueue.ReCheck()
}
func (user *User) backfillAll() {
conversations := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, -1)
if len(conversations) > 0 {
user.zlog.Info().
Int("conversation_count", len(conversations)).
Msg("Probably received all history sync blobs, now backfilling conversations")
limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations
bridgedCount := 0
// Find the portals for all the conversations.
for _, conv := range conversations {
jid, err := types.ParseJID(conv.ConversationID)
log := user.zlog.With().
Str("method", "User.backfillAll").
Logger()
ctx := log.WithContext(context.TODO())
conversations, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, -1)
if err != nil {
log.Err(err).Msg("Failed to get history sync conversations from database")
return
} else if len(conversations) == 0 {
return
}
log.Info().
Int("conversation_count", len(conversations)).
Msg("Probably received all history sync blobs, now backfilling conversations")
limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations
bridgedCount := 0
// Find the portals for all the conversations.
for _, conv := range conversations {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
log.Err(err).
Str("conversation_id", conv.ConversationID).
Msg("Failed to parse chat JID in history sync")
continue
}
portal := user.GetPortalByJID(jid)
if portal.MXID != "" {
log.Debug().
Str("portal_jid", portal.Key.JID.String()).
Msg("Chat already has a room, deleting messages from database")
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
if err != nil {
user.zlog.Warn().Err(err).
Str("conversation_id", conv.ConversationID).
Msg("Failed to parse chat JID in history sync")
continue
log.Err(err).Str("portal_jid", portal.Key.JID.String()).
Msg("Failed to delete history sync conversation with existing portal from database")
}
portal := user.GetPortalByJID(jid)
if portal.MXID != "" {
user.zlog.Debug().
Str("portal_jid", portal.Key.JID.String()).
Msg("Chat already has a room, deleting messages from database")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
bridgedCount++
} else if !user.bridge.DB.HistorySync.ConversationHasMessages(user.MXID, portal.Key) {
user.zlog.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
} else if limit < 0 || bridgedCount < limit {
bridgedCount++
err = portal.CreateMatrixRoom(user, nil, nil, true, true)
if err != nil {
user.zlog.Err(err).Msg("Failed to create Matrix room for backfill")
}
bridgedCount++
} else if hasMessages, err := user.bridge.DB.HistorySync.ConversationHasMessages(ctx, user.MXID, portal.Key); err != nil {
log.Err(err).Str("portal_jid", portal.Key.JID.String()).Msg("Failed to check if chat has messages in history sync")
} else if !hasMessages {
log.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync")
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
if err != nil {
log.Err(err).Str("portal_jid", portal.Key.JID.String()).
Msg("Failed to delete history sync conversation with no messages from database")
}
} else if limit < 0 || bridgedCount < limit {
bridgedCount++
err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, true)
if err != nil {
log.Err(err).Msg("Failed to create Matrix room for backfill")
}
}
}
}
func (portal *Portal) legacyBackfill(user *User) {
func (portal *Portal) legacyBackfill(ctx context.Context, user *User) {
defer portal.latestEventBackfillLock.Unlock()
// This should only be called from CreateMatrixRoom which locks latestEventBackfillLock before creating the room.
if portal.latestEventBackfillLock.TryLock() {
panic("legacyBackfill() called without locking latestEventBackfillLock")
}
// TODO use portal.zlog instead of user.zlog
log := user.zlog.With().
Str("portal_jid", portal.Key.JID.String()).
Str("action", "legacy backfill").
Logger()
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, portal.Key)
messages := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount)
log := zerolog.Ctx(ctx).With().Str("action", "legacy backfill").Logger()
ctx = log.WithContext(ctx)
conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get history sync conversation data for backfill")
return
}
messages, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount)
if err != nil {
log.Err(err).Msg("Failed to get history sync messages for backfill")
return
}
log.Debug().Int("message_count", len(messages)).Msg("Got messages to backfill from database")
for i := len(messages) - 1; i >= 0; i-- {
msgEvt, err := user.Client.ParseWebMessage(portal.Key.JID, messages[i])
@ -194,18 +227,26 @@ func (portal *Portal) legacyBackfill(user *User) {
Msg("Dropping historical message due to parse error")
continue
}
portal.handleMessage(user, msgEvt, true)
ctx := log.With().
Str("message_id", msgEvt.Info.ID).
Stringer("message_sender", msgEvt.Info.Sender).
Logger().
WithContext(ctx)
portal.handleMessage(ctx, user, msgEvt, true)
}
if conv != nil {
isUnread := conv.MarkedAsUnread || conv.UnreadCount > 0
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
shouldMarkAsRead := !isUnread || isTooOld
if shouldMarkAsRead {
user.markSelfReadFull(portal)
user.markSelfReadFull(ctx, portal)
}
}
log.Debug().Msg("Backfill complete, deleting leftover messages from database")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
log.Info().Msg("Backfill complete, deleting leftover messages from database")
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
if err != nil {
log.Err(err).Msg("Failed to delete history sync conversation from database after backfill")
}
}
func (user *User) dailyMediaRequestLoop() {
@ -224,29 +265,49 @@ func (user *User) dailyMediaRequestLoop() {
if requestStartTime.Before(now) {
requestStartTime = requestStartTime.AddDate(0, 0, 1)
}
log := user.zlog.With().
Str("action", "daily media request loop").
Logger()
ctx := log.WithContext(context.Background())
// Wait to start the loop
user.log.Infof("Waiting until %s to do media retry requests", requestStartTime)
log.Info().Time("start_loop_at", requestStartTime).Msg("Waiting until start time to do media retry requests")
time.Sleep(time.Until(requestStartTime))
for {
mediaBackfillRequests := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(user.MXID)
user.log.Infof("Sending %d media retry requests", len(mediaBackfillRequests))
mediaBackfillRequests, err := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(ctx, user.MXID)
if err != nil {
log.Err(err).Msg("Failed to get media retry requests")
} else if len(mediaBackfillRequests) > 0 {
log.Info().Int("media_request_count", len(mediaBackfillRequests)).Msg("Sending media retry requests")
// Send all of the media backfill requests for the user at once
for _, req := range mediaBackfillRequests {
portal := user.GetPortalByJID(req.PortalKey.JID)
_, err := portal.requestMediaRetry(user, req.EventID, req.MediaKey)
if err != nil {
user.log.Warnf("Failed to send media retry request for %s / %s", req.PortalKey.String(), req.EventID)
req.Status = database.MediaBackfillRequestStatusRequestFailed
req.Error = err.Error()
} else {
user.log.Debugfln("Sent media retry request for %s / %s", req.PortalKey.String(), req.EventID)
req.Status = database.MediaBackfillRequestStatusRequested
// Send all the media backfill requests for the user at once
for _, req := range mediaBackfillRequests {
portal := user.GetPortalByJID(req.PortalKey.JID)
_, err = portal.requestMediaRetry(ctx, user, req.EventID, req.MediaKey)
if err != nil {
log.Err(err).
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Failed to send media retry request")
req.Status = database.MediaBackfillRequestStatusRequestFailed
req.Error = err.Error()
} else {
log.Debug().
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Sent media retry request")
req.Status = database.MediaBackfillRequestStatusRequested
}
req.MediaKey = nil
err = req.Upsert(ctx)
if err != nil {
log.Err(err).
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Failed to save status of media retry request")
}
}
req.MediaKey = nil
req.Upsert()
}
// Wait for 24 hours before making requests again
@ -254,20 +315,29 @@ func (user *User) dailyMediaRequestLoop() {
}
}
func (user *User) backfillInChunks(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) {
func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTask, conv *database.HistorySyncConversation, portal *Portal) {
portal.backfillLock.Lock()
defer portal.backfillLock.Unlock()
log := zerolog.Ctx(ctx)
if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(portal.MXID, user.MXID) {
portal.ensureUserInvited(user)
if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(ctx, portal.MXID, user.MXID) {
portal.ensureUserInvited(ctx, user)
}
backfillState := user.bridge.DB.Backfill.GetBackfillState(user.MXID, &portal.Key)
backfillState, err := user.bridge.DB.BackfillState.GetBackfillState(ctx, user.MXID, portal.Key)
if backfillState == nil {
backfillState = user.bridge.DB.Backfill.NewBackfillState(user.MXID, &portal.Key)
backfillState = user.bridge.DB.BackfillState.NewBackfillState(user.MXID, portal.Key)
}
backfillState.SetProcessingBatch(true)
defer backfillState.SetProcessingBatch(false)
err = backfillState.SetProcessingBatch(ctx, true)
if err != nil {
log.Err(err).Msg("Failed to mark batch as being processed")
}
defer func() {
err = backfillState.SetProcessingBatch(ctx, false)
if err != nil {
log.Err(err).Msg("Failed to mark batch as no longer being processed")
}
}()
var timeEnd *time.Time
var forward, shouldMarkAsRead bool
@ -275,17 +345,27 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
if req.BackfillType == database.BackfillForward {
// TODO this overrides the TimeStart set when enqueuing the backfill
// maybe the enqueue should instead include the prev event ID
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get newest message in chat")
return
}
start := lastMessage.Timestamp.Add(1 * time.Second)
req.TimeStart = &start
// Sending events at the end of the room (= latest events)
forward = true
} else {
firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key)
firstMessage, err := portal.bridge.DB.Message.GetFirstInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get oldest message in chat")
return
}
if firstMessage != nil {
end := firstMessage.Timestamp.Add(-1 * time.Second)
timeEnd = &end
user.log.Debugfln("Limiting backfill to end at %v", end)
log.Debug().
Time("oldest_message_ts", firstMessage.Timestamp).
Msg("Limiting backfill to messages older than oldest message")
} else {
// Portal is empty -> events are latest
forward = true
@ -303,45 +383,48 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
shouldMarkAsRead = !isUnread || isTooOld
}
allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
allMsgs, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
sendDisappearedNotice := false
// If expired messages are on, and a notice has not been sent to this chat
// about it having disappeared messages at the conversation timestamp, send
// a notice indicating so.
if len(allMsgs) == 0 && conv.EphemeralExpiration != nil && *conv.EphemeralExpiration > 0 {
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get last message in chat to check if disappeared notice should be sent")
}
if lastMessage == nil || conv.LastMessageTimestamp.After(lastMessage.Timestamp) {
sendDisappearedNotice = true
}
}
if !sendDisappearedNotice && len(allMsgs) == 0 {
user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID)
log.Debug().Msg("Not backfilling chat: no bridgeable messages found")
return
}
if len(portal.MXID) == 0 {
user.log.Debugln("Creating portal for", portal.Key.JID, "as part of history sync handling")
err := portal.CreateMatrixRoom(user, nil, nil, true, false)
log.Debug().Msg("Creating portal for chat as part of history sync handling")
err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, false)
if err != nil {
user.log.Errorfln("Failed to create room for %s during backfill: %v", portal.Key.JID, err)
log.Err(err).Msg("Failed to create room for chat during backfill")
return
}
}
// Update the backfill status here after the room has been created.
portal.updateBackfillStatus(backfillState)
portal.updateBackfillStatus(ctx, backfillState)
if sendDisappearedNotice {
user.log.Debugfln("Sending notice to %s that there are disappeared messages ending at %v", portal.Key.JID, conv.LastMessageTimestamp)
resp, err := portal.sendMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
log.Debug().Time("last_message_time", conv.LastMessageTimestamp).
Msg("Sending notice that there are disappeared messages in the chat")
resp, err := portal.sendMessage(ctx, portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
MsgType: event.MsgNotice,
Body: portal.formatDisappearingMessageNotice(),
}, nil, conv.LastMessageTimestamp.UnixMilli())
if err != nil {
portal.log.Errorln("Error sending disappearing messages notice event")
log.Err(err).Msg("Failed to send disappeared messages notice event")
return
}
@ -353,12 +436,18 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
msg.SenderMXID = portal.MainIntent().UserID
msg.Sent = true
msg.Type = database.MsgFake
msg.Insert(nil)
user.markSelfReadFull(portal)
err = msg.Insert(ctx)
if err != nil {
log.Err(err).Msg("Failed to save fake message entry for disappearing message timer in backfill")
}
user.markSelfReadFull(ctx, portal)
return
}
user.log.Infofln("Backfilling %d messages in %s, %d messages at a time (queue ID: %d)", len(allMsgs), portal.Key.JID, req.MaxBatchEvents, req.QueueID)
log.Info().
Int("message_count", len(allMsgs)).
Int("max_batch_events", req.MaxBatchEvents).
Msg("Backfilling messages")
toBackfill := allMsgs[0:]
for len(toBackfill) > 0 {
var msgs []*waProto.WebMessageInfo
@ -372,14 +461,14 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
if len(msgs) > 0 {
time.Sleep(time.Duration(req.BatchDelay) * time.Second)
user.log.Debugfln("Backfilling %d messages in %s (queue ID: %d)", len(msgs), portal.Key.JID, req.QueueID)
portal.backfill(user, msgs, forward, shouldMarkAsRead)
log.Debug().Int("batch_message_count", len(msgs)).Msg("Backfilling message batch")
portal.backfill(ctx, user, msgs, forward, shouldMarkAsRead)
}
}
user.log.Debugfln("Finished backfilling %d messages in %s (queue ID: %d)", len(allMsgs), portal.Key.JID, req.QueueID)
err := user.bridge.DB.HistorySync.DeleteMessages(user.MXID, conv.ConversationID, allMsgs)
log.Debug().Int("message_count", len(allMsgs)).Msg("Finished backfilling messages in queue entry")
err = user.bridge.DB.HistorySync.DeleteMessages(ctx, user.MXID, conv.ConversationID, allMsgs)
if err != nil {
user.log.Warnfln("Failed to delete %d history sync messages after backfilling (queue ID: %d): %v", len(allMsgs), req.QueueID, err)
log.Err(err).Msg("Failed to delete history sync messages after backfilling")
}
if req.TimeStart == nil {
@ -399,8 +488,11 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
// beginning of time.
backfillState.FirstExpectedTimestamp = 0
}
backfillState.Upsert()
portal.updateBackfillStatus(backfillState)
err = backfillState.Upsert(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill state as completed in database")
}
portal.updateBackfillStatus(ctx, backfillState)
}
}
@ -408,13 +500,13 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
if evt == nil || evt.SyncType == nil {
return
}
log := user.bridge.ZLog.With().
log := user.zlog.With().
Str("method", "User.storeHistorySync").
Str("user_id", user.MXID.String()).
Str("sync_type", evt.GetSyncType().String()).
Uint32("chunk_order", evt.GetChunkOrder()).
Uint32("progress", evt.GetProgress()).
Logger()
ctx := log.WithContext(context.TODO())
if evt.GetGlobalSettings() != nil {
log.Debug().Interface("global_settings", evt.GetGlobalSettings()).Msg("Got global settings in history sync")
}
@ -466,7 +558,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
historySyncConversation := user.bridge.DB.HistorySync.NewConversationWithValues(
user.MXID,
conv.GetId(),
&portal.Key,
portal.Key,
getConversationTimestamp(conv),
conv.GetMuteEndTime(),
conv.GetArchived(),
@ -476,7 +568,10 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
conv.EphemeralExpiration,
conv.GetMarkedAsUnread(),
conv.GetUnreadCount())
historySyncConversation.Upsert()
err := historySyncConversation.Upsert(ctx)
if err != nil {
log.Err(err).Msg("Failed to insert history sync conversation into database")
}
}
var minTime, maxTime time.Time
@ -521,7 +616,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
Msg("Failed to save historical message")
continue
}
err = message.Insert()
err = message.Insert(ctx)
if err != nil {
log.Error().Err(err).
Int("msg_index", i).
@ -570,15 +665,20 @@ func getConversationTimestamp(conv *waProto.Conversation) uint64 {
return convTs
}
func (user *User) EnqueueImmediateBackfills(portals []*Portal) {
func (user *User) EnqueueImmediateBackfills(ctx context.Context, portals []*Portal) {
for priority, portal := range portals {
maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents
initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, maxMessages, maxMessages, 0)
initialBackfill.Insert()
initialBackfill := user.bridge.DB.BackfillQueue.NewWithValues(user.MXID, database.BackfillImmediate, priority, portal.Key, nil, maxMessages, maxMessages, 0)
err := initialBackfill.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert immediate backfill into database")
}
}
}
func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
func (user *User) EnqueueDeferredBackfills(ctx context.Context, portals []*Portal) {
numPortals := len(portals)
for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Deferred {
for portalIdx, portal := range portals {
@ -587,22 +687,36 @@ func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
startDaysAgo := time.Now().AddDate(0, 0, -backfillStage.StartDaysAgo)
startDate = &startDaysAgo
}
backfillMessages := user.bridge.DB.Backfill.NewWithValues(
user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
backfillMessages.Insert()
backfillMessages := user.bridge.DB.BackfillQueue.NewWithValues(
user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
err := backfillMessages.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert deferred backfill into database")
}
}
}
}
func (user *User) EnqueueForwardBackfills(portals []*Portal) {
func (user *User) EnqueueForwardBackfills(ctx context.Context, portals []*Portal) {
for priority, portal := range portals {
lastMsg := user.bridge.DB.Message.GetLastInChat(portal.Key)
if lastMsg == nil {
lastMsg, err := user.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to get last message in chat to enqueue forward backfill")
} else if lastMsg == nil {
continue
}
backfill := user.bridge.DB.Backfill.NewWithValues(
user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, -1, -1, 0)
backfill.Insert()
backfill := user.bridge.DB.BackfillQueue.NewWithValues(
user.MXID, database.BackfillForward, priority, portal.Key, &lastMsg.Timestamp, -1, -1, 0)
err = backfill.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert forward backfill into database")
}
}
}
@ -619,12 +733,11 @@ func (portal *Portal) deterministicEventID(sender types.JID, messageID types.Mes
}
var (
PortalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType}
BackfillStatusEvent = event.Type{Type: "com.beeper.backfill_status", Class: event.StateEventType}
)
func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) *mautrix.RespBeeperBatchSend {
func (portal *Portal) backfill(ctx context.Context, source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) {
log := zerolog.Ctx(ctx)
var req mautrix.ReqBeeperBatchSend
var infos []*wrappedInfo
@ -633,7 +746,10 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
req.MarkReadBy = source.MXID
}
portal.log.Infofln("Processing history sync with %d messages (forward: %t)", len(messages), isForward)
log.Info().
Bool("forward", isForward).
Int("message_count", len(messages)).
Msg("Processing history sync message batch")
// The messages are ordered newest to oldest, so iterate them in reverse order.
for i := len(messages) - 1; i >= 0; i-- {
webMsg := messages[i]
@ -641,11 +757,16 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
if err != nil {
continue
}
log := log.With().
Str("message_id", msgEvt.Info.ID).
Stringer("message_sender", msgEvt.Info.Sender).
Logger()
ctx := log.WithContext(ctx)
msgType := getMessageType(msgEvt.Message)
if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" {
if msgType != "ignore" {
portal.log.Debugfln("Skipping message %s with unknown type in backfill", msgEvt.Info.ID)
log.Debug().Msg("Skipping message with unknown type in backfill")
}
continue
}
@ -654,85 +775,83 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
if !existingContact.Found || existingContact.PushName == "" {
changed, _, err := source.Client.Store.Contacts.PutPushName(msgEvt.Info.Sender, webMsg.GetPushName())
if err != nil {
source.log.Errorfln("Failed to save push name of %s from historical message in device store: %v", msgEvt.Info.Sender, err)
log.Err(err).Msg("Failed to save push name from historical message to device store")
} else if changed {
source.log.Debugfln("Got push name %s for %s from historical message", webMsg.GetPushName(), msgEvt.Info.Sender)
log.Debug().Str("push_name", webMsg.GetPushName()).Msg("Got push name from historical message")
}
}
}
puppet := portal.getMessagePuppet(source, &msgEvt.Info)
puppet := portal.getMessagePuppet(ctx, source, &msgEvt.Info)
if puppet == nil {
continue
}
converted := portal.convertMessage(puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true)
converted := portal.convertMessage(ctx, puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true)
if converted == nil {
portal.log.Debugfln("Skipping unsupported message %s in backfill", msgEvt.Info.ID)
log.Debug().Msg("Skipping unsupported message in backfill")
continue
}
if converted.ReplyTo != nil {
portal.SetReply(msgEvt.Info.ID, converted.Content, converted.ReplyTo, true)
portal.SetReply(ctx, converted.Content, converted.ReplyTo, true)
}
err = portal.appendBatchEvents(source, converted, &msgEvt.Info, webMsg, &req.Events, &infos)
err = portal.appendBatchEvents(ctx, source, converted, &msgEvt.Info, webMsg, &req.Events, &infos)
if err != nil {
portal.log.Errorfln("Error handling message %s during backfill: %v", msgEvt.Info.ID, err)
log.Err(err).Msg("Failed to handle message in backfill")
}
}
portal.log.Infofln("Made %d Matrix events from messages in batch", len(req.Events))
log.Info().Int("event_count", len(req.Events)).Msg("Made Matrix events from messages in batch")
if len(req.Events) == 0 {
return nil
return
}
resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &req)
resp, err := portal.MainIntent().BeeperBatchSend(ctx, portal.MXID, &req)
if err != nil {
portal.log.Errorln("Error batch sending messages:", err)
return nil
} else {
txn, err := portal.bridge.DB.Begin()
if err != nil {
portal.log.Errorln("Failed to start transaction to save batch messages:", err)
return nil
}
portal.finishBatch(txn, resp.EventIDs, infos)
err = txn.Commit()
if err != nil {
portal.log.Errorln("Failed to commit transaction to save batch messages:", err)
return nil
}
if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
go portal.requestMediaRetries(source, resp.EventIDs, infos)
}
return resp
log.Err(err).Msg("Failed to send batch of messages")
return
}
err = portal.bridge.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
return portal.finishBatch(ctx, resp.EventIDs, infos)
})
if err != nil {
log.Err(err).Msg("Failed to save message batch to database")
return
}
log.Info().Msg("Successfully sent backfill batch")
if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
go portal.requestMediaRetries(context.TODO(), source, resp.EventIDs, infos)
}
}
func (portal *Portal) requestMediaRetries(source *User, eventIDs []id.EventID, infos []*wrappedInfo) {
func (portal *Portal) requestMediaRetries(ctx context.Context, source *User, eventIDs []id.EventID, infos []*wrappedInfo) {
for i, info := range infos {
if info != nil && info.Error == database.MsgErrMediaNotFound && info.MediaKey != nil {
switch portal.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod {
case config.MediaRequestMethodImmediate:
err := source.Client.SendMediaRetryReceipt(info.MessageInfo, info.MediaKey)
if err != nil {
portal.log.Warnfln("Failed to send post-backfill media retry request for %s: %v", info.ID, err)
portal.zlog.Err(err).Str("message_id", info.ID).Msg("Failed to send post-backfill media retry request")
} else {
portal.log.Debugfln("Sent post-backfill media retry request for %s", info.ID)
portal.zlog.Debug().Str("message_id", info.ID).Msg("Sent post-backfill media retry request")
}
case config.MediaRequestMethodLocalTime:
req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, &portal.Key, eventIDs[i], info.MediaKey)
req.Upsert()
req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, portal.Key, eventIDs[i], info.MediaKey)
err := req.Upsert(ctx)
if err != nil {
portal.zlog.Err(err).
Stringer("event_id", eventIDs[i]).
Msg("Failed to upsert media backfill request")
}
}
}
}
}
func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
if portal.bridge.Config.Bridge.CaptionInMessage {
converted.MergeCaption()
}
mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
mainEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
if err != nil {
return err
}
@ -750,7 +869,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
ExpiresIn: converted.ExpiresIn,
}
if converted.Caption != nil {
captionEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
captionEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
if err != nil {
return err
}
@ -762,7 +881,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
}
if converted.MultiEvent != nil {
for i, subEvtContent := range converted.MultiEvent {
subEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
subEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
if err != nil {
return err
}
@ -771,7 +890,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
}
}
for _, reaction := range raw.GetReactions() {
reactionEvent, reactionInfo := portal.wrapBatchReaction(source, reaction, mainEvt.ID, info.Timestamp)
reactionEvent, reactionInfo := portal.wrapBatchReaction(ctx, source, reaction, mainEvt.ID, info.Timestamp)
if reactionEvent != nil {
*eventsArray = append(*eventsArray, reactionEvent)
*infoArray = append(*infoArray, &wrappedInfo{
@ -785,7 +904,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
return nil
}
func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) {
func (portal *Portal) wrapBatchReaction(ctx context.Context, source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) {
var senderJID types.JID
if reaction.GetKey().GetFromMe() {
senderJID = source.JID.ToNonAD()
@ -807,7 +926,7 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction
ID: reaction.GetKey().GetId(),
Timestamp: mainEventTS,
}
puppet := portal.getMessagePuppet(source, reactionInfo)
puppet := portal.getMessagePuppet(ctx, source, reactionInfo)
if puppet == nil {
return
}
@ -834,12 +953,12 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction
return
}
func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
func (portal *Portal) wrapBatchEvent(ctx context.Context, info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
wrappedContent := event.Content{
Parsed: content,
Raw: extraContent,
}
newEventType, err := portal.encrypt(intent, &wrappedContent, eventType)
newEventType, err := portal.encrypt(ctx, intent, &wrappedContent, eventType)
if err != nil {
return nil, err
}
@ -853,37 +972,37 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
}, nil
}
func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID, infos []*wrappedInfo) {
func (portal *Portal) finishBatch(ctx context.Context, eventIDs []id.EventID, infos []*wrappedInfo) error {
for i, info := range infos {
if info == nil {
continue
}
eventID := eventIDs[i]
portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error)
portal.markHandled(ctx, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error)
if info.Type == database.MsgReaction {
portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
portal.upsertReaction(ctx, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
}
if info.ExpiresIn > 0 {
portal.MarkDisappearing(txn, eventID, info.ExpiresIn, info.ExpirationStart)
portal.MarkDisappearing(ctx, eventID, info.ExpiresIn, info.ExpirationStart)
}
}
portal.log.Infofln("Successfully sent %d events", len(eventIDs))
return nil
}
func (portal *Portal) updateBackfillStatus(backfillState *database.BackfillState) {
func (portal *Portal) updateBackfillStatus(ctx context.Context, backfillState *database.BackfillState) {
backfillStatus := "backfilling"
if backfillState.BackfillComplete {
backfillStatus = "complete"
}
_, err := portal.bridge.Bot.SendStateEvent(portal.MXID, BackfillStatusEvent, "", map[string]interface{}{
_, err := portal.bridge.Bot.SendStateEvent(ctx, portal.MXID, BackfillStatusEvent, "", map[string]interface{}{
"status": backfillStatus,
"first_timestamp": backfillState.FirstExpectedTimestamp * 1000,
})
if err != nil {
portal.log.Errorln("Error sending backfill status event:", err)
zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill status event to room")
}
}

62
main.go
View File

@ -17,6 +17,7 @@
package main
import (
"context"
_ "embed"
"net/http"
"net/url"
@ -26,15 +27,18 @@ import (
"sync"
"time"
"github.com/rs/zerolog"
waLog "go.mau.fi/whatsmeow/util/log"
"google.golang.org/protobuf/proto"
"go.mau.fi/util/configupgrade"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/util/configupgrade"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/bridge/status"
@ -91,7 +95,7 @@ func (br *WABridge) Init() {
br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
Analytics.log = br.Log.Sub("Analytics")
Analytics.log = br.ZLog.With().Str("component", "analytics").Logger()
Analytics.url = (&url.URL{
Scheme: "https",
Host: br.Config.Analytics.Host,
@ -100,23 +104,20 @@ func (br *WABridge) Init() {
Analytics.key = br.Config.Analytics.Token
Analytics.userID = br.Config.Analytics.UserID
if Analytics.IsEnabled() {
Analytics.log.Infoln("Analytics metrics are enabled")
if Analytics.userID != "" {
Analytics.log.Infoln("Overriding analytics user_id with %v", Analytics.userID)
}
Analytics.log.Info().Str("override_user_id", Analytics.userID).Msg("Analytics metrics are enabled")
}
br.DB = database.New(br.Bridge.DB, br.Log.Sub("Database"))
br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), &waLogger{br.Log.Sub("Database").Sub("WhatsApp")})
br.DB = database.New(br.Bridge.DB)
br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), waLog.Zerolog(br.ZLog.With().Str("db_section", "whatsmeow").Logger()))
br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
ss := br.Config.Bridge.Provisioning.SharedSecret
if len(ss) > 0 && ss != "disable" {
br.Provisioning = &ProvisioningAPI{bridge: br}
br.Provisioning = &ProvisioningAPI{bridge: br, log: br.ZLog.With().Str("component", "provisioning").Logger()}
}
br.Formatter = NewFormatter(br)
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.ZLog.With().Str("component", "metrics").Logger(), br.DB)
br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent
store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
@ -148,11 +149,10 @@ func (br *WABridge) Init() {
func (br *WABridge) Start() {
err := br.WAContainer.Upgrade()
if err != nil {
br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err)
br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to upgrade whatsmeow database")
os.Exit(15)
}
if br.Provisioning != nil {
br.Log.Debugln("Initializing provisioning API")
br.Provisioning.Init()
}
go br.CheckWhatsAppUpdate()
@ -166,30 +166,40 @@ func (br *WABridge) Start() {
}
func (br *WABridge) CheckWhatsAppUpdate() {
br.Log.Debugfln("Checking for WhatsApp web update")
br.ZLog.Debug().Msg("Checking for WhatsApp web update")
resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
if err != nil {
br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
br.ZLog.Warn().Err(err).Msg("Failed to check for WhatsApp web update")
return
}
if store.GetWAVersion() == resp.ParsedVersion {
br.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
br.ZLog.Debug().Msg("Bridge is using latest WhatsApp web protocol")
} else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
if resp.IsBelowHard || resp.IsBroken {
br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
br.ZLog.Warn().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore")
} else if resp.IsBelowSoft {
br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
br.ZLog.Info().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol")
} else {
br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
br.ZLog.Debug().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol")
}
} else {
br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
br.ZLog.Debug().Msg("Bridge is using newer than latest WhatsApp web protocol")
}
}
func (br *WABridge) Loop() {
ctx := br.ZLog.With().Str("action", "background loop").Logger().WithContext(context.TODO())
for {
br.SleepAndDeleteUpcoming()
br.SleepAndDeleteUpcoming(ctx)
time.Sleep(1 * time.Hour)
br.WarnUsersAboutDisconnection()
}
@ -199,14 +209,14 @@ func (br *WABridge) WarnUsersAboutDisconnection() {
br.usersLock.Lock()
for _, user := range br.usersByUsername {
if user.IsConnected() && !user.PhoneRecentlySeen(true) {
go user.sendPhoneOfflineWarning()
go user.sendPhoneOfflineWarning(context.TODO())
}
}
br.usersLock.Unlock()
}
func (br *WABridge) StartUsers() {
br.Log.Debugln("Starting users")
br.ZLog.Debug().Msg("Starting users")
foundAnySessions := false
for _, user := range br.GetAllUsers() {
if !user.JID.IsEmpty() {
@ -217,13 +227,13 @@ func (br *WABridge) StartUsers() {
if !foundAnySessions {
br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}.Fill(nil))
}
br.Log.Debugln("Starting custom puppets")
br.ZLog.Debug().Msg("Starting custom puppets")
for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
go func(puppet *Puppet) {
puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
puppet.zlog.Debug().Stringer("custom_mxid", puppet.CustomMXID).Msg("Starting double puppet")
err := puppet.StartCustomMXID(true)
if err != nil {
puppet.log.Errorln("Failed to start custom puppet:", err)
puppet.zlog.Err(err).Stringer("custom_mxid", puppet.CustomMXID).Msg("Failed to start double puppet")
}
}(loopuppet)
}
@ -235,7 +245,7 @@ func (br *WABridge) Stop() {
if user.Client == nil {
continue
}
br.Log.Debugln("Disconnecting", user.MXID)
user.zlog.Debug().Msg("Disconnecting user")
user.Client.Disconnect()
close(user.historySyncs)
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,8 +17,10 @@
package main
import (
"context"
"fmt"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix"
@ -35,77 +37,89 @@ func (br *WABridge) CreatePrivatePortal(roomID id.RoomID, brInviter bridge.User,
puppet := brGhost.(*Puppet)
key := database.NewPortalKey(puppet.JID, inviter.JID)
portal := br.GetPortalByJID(key)
log := br.ZLog.With().
Str("action", "create private portal").
Stringer("target_room_id", roomID).
Stringer("inviter_mxid", inviter.MXID).
Stringer("invitee_jid", puppet.JID).
Logger()
ctx := log.WithContext(context.TODO())
if len(portal.MXID) == 0 {
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal)
br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
return
}
ok := portal.ensureUserInvited(inviter)
ok := portal.ensureUserInvited(ctx, inviter)
if !ok {
br.Log.Warnfln("Failed to invite %s to existing private chat portal %s with %s. Redirecting portal to new room...", inviter.MXID, portal.MXID, puppet.JID)
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal)
log.Warn().Msg("Failed to invite user to existing private chat portal. Redirecting portal to new room...")
br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
return
}
intent := puppet.DefaultIntent()
errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%[1]s](https://matrix.to/#/%[1]s)", portal.MXID)
errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Config.Homeserver.Domain).MatrixToURL())
errorContent := format.RenderMarkdown(errorMessage, true, false)
_, _ = intent.SendMessageEvent(roomID, event.EventMessage, errorContent)
br.Log.Debugfln("Leaving private chat room %s as %s after accepting invite from %s as we already have chat with the user", roomID, puppet.MXID, inviter.MXID)
_, _ = intent.LeaveRoom(roomID)
_, _ = intent.SendMessageEvent(ctx, roomID, event.EventMessage, errorContent)
log.Debug().Msg("Leaving private chat room from invite as we already have chat with the user")
_, _ = intent.LeaveRoom(ctx, roomID)
}
func (br *WABridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) {
func (br *WABridge) createPrivatePortalFromInvite(ctx context.Context, roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) {
log := zerolog.Ctx(ctx)
// TODO check if room is already encrypted
var existingEncryption event.EncryptionEventContent
var encryptionEnabled bool
err := portal.MainIntent().StateEvent(roomID, event.StateEncryption, "", &existingEncryption)
err := portal.MainIntent().StateEvent(ctx, roomID, event.StateEncryption, "", &existingEncryption)
if err != nil {
portal.log.Warnfln("Failed to check if encryption is enabled in private chat room %s", roomID)
log.Err(err).Msg("Failed to check if encryption is enabled")
} else {
encryptionEnabled = existingEncryption.Algorithm == id.AlgorithmMegolmV1
}
portal.MXID = roomID
portal.updateLogger()
portal.Topic = PrivateChatTopic
portal.Name = puppet.Displayname
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
portal.log.Infofln("Created private chat portal in %s after invite from %s", roomID, inviter.MXID)
log.Info().Msg("Created private chat portal from invite")
intent := puppet.DefaultIntent()
if br.Config.Bridge.Encryption.Default || encryptionEnabled {
_, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID})
_, err = intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID})
if err != nil {
portal.log.Warnln("Failed to invite bridge bot to enable e2be:", err)
log.Err(err).Msg("Failed to invite bridge bot to enable e2be")
}
err = br.Bot.EnsureJoined(roomID)
err = br.Bot.EnsureJoined(ctx, roomID)
if err != nil {
portal.log.Warnln("Failed to join as bridge bot to enable e2be:", err)
log.Err(err).Msg("Failed to join as bridge bot to enable e2be")
}
if !encryptionEnabled {
_, err = intent.SendStateEvent(roomID, event.StateEncryption, "", portal.GetEncryptionEventContent())
_, err = intent.SendStateEvent(ctx, roomID, event.StateEncryption, "", portal.GetEncryptionEventContent())
if err != nil {
portal.log.Warnln("Failed to enable e2be:", err)
log.Err(err).Msg("Failed to enable e2be")
}
}
br.AS.StateStore.SetMembership(roomID, inviter.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, puppet.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, br.Bot.UserID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, inviter.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, puppet.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, br.Bot.UserID, event.MembershipJoin)
portal.Encrypted = true
}
_, _ = portal.MainIntent().SetRoomTopic(portal.MXID, portal.Topic)
_, _ = portal.MainIntent().SetRoomTopic(ctx, portal.MXID, portal.Topic)
if portal.shouldSetDMRoomMetadata() {
_, err = portal.MainIntent().SetRoomName(portal.MXID, portal.Name)
_, err = portal.MainIntent().SetRoomName(ctx, portal.MXID, portal.Name)
portal.NameSet = err == nil
_, err = portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL)
_, err = portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, portal.AvatarURL)
portal.AvatarSet = err == nil
}
portal.Update(nil)
portal.UpdateBridgeInfo()
_, _ = intent.SendNotice(roomID, "Private chat portal created")
err = portal.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal to database after creating from invite")
}
portal.UpdateBridgeInfo(ctx)
_, _ = intent.SendNotice(ctx, roomID, "Private chat portal created")
}
func (br *WABridge) HandlePresence(evt *event.Event) {
func (br *WABridge) HandlePresence(ctx context.Context, evt *event.Event) {
user := br.GetUserByMXIDIfExists(evt.Sender)
if user == nil || !user.IsLoggedIn() {
return
@ -119,15 +133,15 @@ func (br *WABridge) HandlePresence(evt *event.Event) {
presence := types.PresenceAvailable
if evt.Content.AsPresence().Presence != event.PresenceOnline {
presence = types.PresenceUnavailable
user.log.Debugln("Marking offline")
user.zlog.Debug().Msg("Marking offline")
} else {
user.log.Debugln("Marking online")
user.zlog.Debug().Msg("Marking online")
}
user.lastPresence = presence
if user.Client.Store.PushName != "" {
err := user.Client.SendPresence(presence)
if err != nil {
user.log.Warnln("Failed to set presence:", err)
user.zlog.Err(err).Msg("Failed to set presence")
}
}
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -23,7 +23,7 @@ import (
"sync"
"time"
log "maunium.net/go/maulogger/v2"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow"
@ -123,7 +123,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
}
}
func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType string, confirmed bool, editID id.EventID) id.EventID {
func (portal *Portal) sendErrorMessage(ctx context.Context, evt *event.Event, err error, confirmed bool, editID id.EventID) id.EventID {
if !portal.bridge.Config.Bridge.MessageErrorNotices {
return ""
}
@ -131,6 +131,21 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri
if confirmed {
certainty = "was not"
}
var msgType string
switch evt.Type {
case event.EventMessage:
msgType = "message"
case event.EventReaction:
msgType = "reaction"
case event.EventRedaction:
msgType = "redaction"
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default:
msgType = "unknown event"
}
msg := fmt.Sprintf("\u26a0 Your %s %s bridged: %v", msgType, certainty, err)
if errors.Is(err, errMessageTakingLong) {
msg = fmt.Sprintf("\u26a0 Bridging your %s is taking longer than usual", msgType)
@ -144,15 +159,15 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri
} else {
content.SetReply(evt)
}
resp, err := portal.sendMainIntentMessage(content)
resp, err := portal.sendMainIntentMessage(ctx, content)
if err != nil {
portal.log.Warnfln("Failed to send bridging error message:", err)
zerolog.Ctx(ctx).Err(err).Msg("Failed to send bridging error message")
return ""
}
return resp.EventID
}
func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) {
func (portal *Portal) sendStatusEvent(ctx context.Context, evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) {
if !portal.bridge.Config.Bridge.MessageStatusEvents {
return
}
@ -179,75 +194,56 @@ func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, de
content.Reason, content.Status, _, _, content.Message = errorToStatusReason(err)
content.Error = err.Error()
}
_, err = intent.SendMessageEvent(portal.MXID, event.BeeperMessageStatus, &content)
_, err = intent.SendMessageEvent(ctx, portal.MXID, event.BeeperMessageStatus, &content)
if err != nil {
portal.log.Warnln("Failed to send message status event:", err)
zerolog.Ctx(ctx).Err(err).Msg("Failed to send message status event")
}
}
func (portal *Portal) sendDeliveryReceipt(eventID id.EventID) {
func (portal *Portal) sendDeliveryReceipt(ctx context.Context, eventID id.EventID) {
if portal.bridge.Config.Bridge.DeliveryReceipts {
err := portal.bridge.Bot.SendReceipt(portal.MXID, eventID, event.ReceiptTypeRead, nil)
err := portal.bridge.Bot.SendReceipt(ctx, portal.MXID, eventID, event.ReceiptTypeRead, nil)
if err != nil {
portal.log.Debugfln("Failed to send delivery receipt for %s: %v", eventID, err)
zerolog.Ctx(ctx).Err(err).Msg("Failed to mark message as read by bot (Matrix-side delivery receipt)")
}
}
}
func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part string, ms *metricSender) {
var msgType string
switch evt.Type {
case event.EventMessage:
msgType = "message"
case event.EventReaction:
msgType = "reaction"
case event.EventRedaction:
msgType = "redaction"
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default:
msgType = "unknown event"
}
evtDescription := evt.ID.String()
if evt.Type == event.EventRedaction {
evtDescription += fmt.Sprintf(" of %s", evt.Redacts)
}
func (portal *Portal) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, ms *metricSender) {
origEvtID := evt.ID
if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil {
origEvtID = retryMeta.OriginalEventID
}
if err != nil {
level := log.LevelError
level := zerolog.ErrorLevel
if part == "Ignoring" {
level = log.LevelDebug
level = zerolog.DebugLevel
}
portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err)
zerolog.Ctx(ctx).WithLevel(level).Err(err).Msg(part + " Matrix event")
reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err)
checkpointStatus := status.ReasonToCheckpointStatus(reason, statusCode)
portal.bridge.SendMessageCheckpoint(evt, status.MsgStepRemote, err, checkpointStatus, ms.getRetryNum())
if sendNotice {
ms.setNoticeID(portal.sendErrorMessage(evt, err, msgType, isCertain, ms.getNoticeID()))
ms.setNoticeID(portal.sendErrorMessage(ctx, evt, err, isCertain, ms.getNoticeID()))
}
portal.sendStatusEvent(origEvtID, evt.ID, err, nil)
portal.sendStatusEvent(ctx, origEvtID, evt.ID, err, nil)
} else {
portal.log.Debugfln("Handled Matrix %s %s", msgType, evtDescription)
portal.sendDeliveryReceipt(evt.ID)
zerolog.Ctx(ctx).Debug().Msg("Successfully handled Matrix event")
portal.sendDeliveryReceipt(ctx, evt.ID)
portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, ms.getRetryNum())
var deliveredTo *[]id.UserID
if portal.IsPrivateChat() {
deliveredTo = &[]id.UserID{}
}
portal.sendStatusEvent(origEvtID, evt.ID, nil, deliveredTo)
portal.sendStatusEvent(ctx, origEvtID, evt.ID, nil, deliveredTo)
if prevNotice := ms.popNoticeID(); prevNotice != "" {
_, _ = portal.MainIntent().RedactEvent(portal.MXID, prevNotice, mautrix.ReqRedact{
_, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, prevNotice, mautrix.ReqRedact{
Reason: "error resolved",
})
}
}
if ms != nil {
portal.log.Debugfln("Timings for %s: %s", evt.ID, ms.timings.String())
zerolog.Ctx(ctx).Debug().Object("timings", ms.timings).Msg("Matrix event timings")
}
}
@ -264,47 +260,16 @@ type messageTimings struct {
totalSend time.Duration
}
func niceRound(dur time.Duration) time.Duration {
switch {
case dur < time.Millisecond:
return dur
case dur < time.Second:
return dur.Round(100 * time.Microsecond)
default:
return dur.Round(time.Millisecond)
}
}
func (mt *messageTimings) String() string {
mt.initReceive = niceRound(mt.initReceive)
mt.decrypt = niceRound(mt.decrypt)
mt.portalQueue = niceRound(mt.portalQueue)
mt.totalReceive = niceRound(mt.totalReceive)
mt.implicitRR = niceRound(mt.implicitRR)
mt.preproc = niceRound(mt.preproc)
mt.convert = niceRound(mt.convert)
mt.whatsmeow.Queue = niceRound(mt.whatsmeow.Queue)
mt.whatsmeow.Marshal = niceRound(mt.whatsmeow.Marshal)
mt.whatsmeow.GetParticipants = niceRound(mt.whatsmeow.GetParticipants)
mt.whatsmeow.GetDevices = niceRound(mt.whatsmeow.GetDevices)
mt.whatsmeow.GroupEncrypt = niceRound(mt.whatsmeow.GroupEncrypt)
mt.whatsmeow.PeerEncrypt = niceRound(mt.whatsmeow.PeerEncrypt)
mt.whatsmeow.Send = niceRound(mt.whatsmeow.Send)
mt.whatsmeow.Resp = niceRound(mt.whatsmeow.Resp)
mt.whatsmeow.Retry = niceRound(mt.whatsmeow.Retry)
mt.totalSend = niceRound(mt.totalSend)
whatsmeowTimings := "N/A"
if mt.totalSend > 0 {
format := "queue: %[1]s, marshal: %[2]s, ske: %[3]s, pcp: %[4]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
if mt.whatsmeow.GetParticipants == 0 && mt.whatsmeow.GroupEncrypt == 0 {
format = "queue: %[1]s, marshal: %[2]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
}
if mt.whatsmeow.Retry > 0 {
format += ", retry: %[9]s"
}
whatsmeowTimings = fmt.Sprintf(format, mt.whatsmeow.Queue, mt.whatsmeow.Marshal, mt.whatsmeow.GroupEncrypt, mt.whatsmeow.GetParticipants, mt.whatsmeow.GetDevices, mt.whatsmeow.PeerEncrypt, mt.whatsmeow.Send, mt.whatsmeow.Resp, mt.whatsmeow.Retry)
}
return fmt.Sprintf("BRIDGE: receive: %s, decrypt: %s, queue: %s, total hs->portal: %s, implicit rr: %s -- PORTAL: preprocess: %s, convert: %s, total send: %s -- WHATSMEOW: %s", mt.initReceive, mt.decrypt, mt.implicitRR, mt.portalQueue, mt.totalReceive, mt.preproc, mt.convert, mt.totalSend, whatsmeowTimings)
func (mt *messageTimings) MarshalZerologObject(e *zerolog.Event) {
e.Dur("init_receive", mt.initReceive).
Dur("decrypt", mt.decrypt).
Dur("implicit_rr", mt.implicitRR).
Dur("portal_queue", mt.portalQueue).
Dur("total_receive", mt.totalReceive).
Dur("preproc", mt.preproc).
Dur("convert", mt.convert).
Object("whatsmeow", mt.whatsmeow).
Dur("total_send", mt.totalSend)
}
type metricSender struct {
@ -345,13 +310,13 @@ func (ms *metricSender) setNoticeID(evtID id.EventID) {
}
}
func (ms *metricSender) sendMessageMetrics(evt *event.Event, err error, part string, completed bool) {
func (ms *metricSender) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, completed bool) {
ms.lock.Lock()
defer ms.lock.Unlock()
if !completed && ms.completed {
return
}
ms.portal.sendMessageMetrics(evt, err, part, ms)
ms.portal.sendMessageMetrics(ctx, evt, err, part, ms)
ms.retryNum++
ms.completed = completed
}

View File

@ -18,6 +18,7 @@ package main
import (
"context"
"errors"
"net/http"
"runtime/debug"
"strconv"
@ -27,7 +28,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "maunium.net/go/maulogger/v2"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
@ -40,7 +41,7 @@ import (
type MetricsHandler struct {
db *database.Database
server *http.Server
log log.Logger
log zerolog.Logger
running bool
ctx context.Context
@ -70,7 +71,7 @@ type MetricsHandler struct {
loggedInStateLock sync.Mutex
}
func NewMetricsHandler(address string, log log.Logger, db *database.Database) *MetricsHandler {
func NewMetricsHandler(address string, log zerolog.Logger, db *database.Database) *MetricsHandler {
portalCount := promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "whatsapp_portals_total",
Help: "Number of portal rooms on Matrix",
@ -232,31 +233,31 @@ func (mh *MetricsHandler) TrackConnectionState(jid types.JID, connected bool) {
func (mh *MetricsHandler) updateStats() {
start := time.Now()
var puppetCount int
err := mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
err := mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
if err != nil {
mh.log.Warnln("Failed to scan number of puppets:", err)
mh.log.Err(err).Msg("Failed to scan number of puppets")
} else {
mh.puppetCount.Set(float64(puppetCount))
}
var userCount int
err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
err = mh.db.QueryRow(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
if err != nil {
mh.log.Warnln("Failed to scan number of users:", err)
mh.log.Err(err).Msg("Failed to scan number of users")
} else {
mh.userCount.Set(float64(userCount))
}
var messageCount int
err = mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
err = mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
if err != nil {
mh.log.Warnln("Failed to scan number of messages:", err)
mh.log.Err(err).Msg("Failed to scan number of messages")
} else {
mh.messageCount.Set(float64(messageCount))
}
var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int
err = mh.db.QueryRowContext(mh.ctx, `
err = mh.db.QueryRow(mh.ctx, `
SELECT
COUNT(CASE WHEN jid LIKE '%@g.us' AND encrypted THEN 1 END) AS encrypted_group_portals,
COUNT(CASE WHEN jid LIKE '%@s.whatsapp.net' AND encrypted THEN 1 END) AS encrypted_private_portals,
@ -265,7 +266,7 @@ func (mh *MetricsHandler) updateStats() {
FROM portal WHERE mxid<>''
`).Scan(&encryptedGroupCount, &encryptedPrivateCount, &unencryptedGroupCount, &unencryptedPrivateCount)
if err != nil {
mh.log.Warnln("Failed to scan number of portals:", err)
mh.log.Err(err).Msg("Failed to scan number of portals")
} else {
mh.encryptedGroupCount.Set(float64(encryptedGroupCount))
mh.encryptedPrivateCount.Set(float64(encryptedPrivateCount))
@ -279,7 +280,10 @@ func (mh *MetricsHandler) startUpdatingStats() {
defer func() {
err := recover()
if err != nil {
mh.log.Fatalfln("Panic in metric updater: %v\n%s", err, string(debug.Stack()))
mh.log.WithLevel(zerolog.PanicLevel).
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
Interface(zerolog.ErrorFieldName, err).
Msg("Panic in metric updater")
}
}()
ticker := time.Tick(10 * time.Second)
@ -299,8 +303,8 @@ func (mh *MetricsHandler) Start() {
go mh.startUpdatingStats()
err := mh.server.ListenAndServe()
mh.running = false
if err != nil && err != http.ErrServerClosed {
mh.log.Fatalln("Error in metrics listener:", err)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
mh.log.Err(err).Msg("Error in metrics listener")
}
}
@ -311,6 +315,6 @@ func (mh *MetricsHandler) Stop() {
mh.stopRecorder()
err := mh.server.Close()
if err != nil {
mh.log.Errorln("Error closing metrics listener:", err)
mh.log.Err(err).Msg("Failed to close metrics listener")
}
}

1991
portal.go

File diff suppressed because it is too large Load Diff

View File

@ -17,41 +17,40 @@
package main
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"strings"
"time"
"github.com/beeper/libserv/pkg/requestlog"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"go.mau.fi/whatsmeow/appstate"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/id"
)
type ProvisioningAPI struct {
bridge *WABridge
log log.Logger
log zerolog.Logger
}
func (prov *ProvisioningAPI) Init() {
prov.log = prov.bridge.Log.Sub("Provisioning")
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
prov.log.Debug().Str("base_path", prov.bridge.Config.Bridge.Provisioning.Prefix).Msg("Enabling provisioning API")
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
r.Use(hlog.NewHandler(prov.log))
r.Use(requestlog.AccessLogger(true))
r.Use(prov.AuthMiddleware)
r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
@ -73,7 +72,7 @@ func (prov *ProvisioningAPI) Init() {
prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.bridge_state", prov.BridgeStatePing).Methods(http.MethodPost)
if prov.bridge.Config.Bridge.Provisioning.DebugEndpoints {
prov.log.Debugln("Enabling debug API at /debug")
prov.log.Debug().Msg("Enabling debug API at /debug")
r := prov.bridge.AS.Router.PathPrefix("/debug").Subrouter()
r.Use(prov.AuthMiddleware)
r.PathPrefix("/pprof").Handler(http.DefaultServeMux)
@ -83,26 +82,6 @@ func (prov *ProvisioningAPI) Init() {
r.HandleFunc("/v1/delete_connection", prov.Disconnect).Methods(http.MethodPost)
}
type responseWrap struct {
http.ResponseWriter
statusCode int
}
var _ http.Hijacker = (*responseWrap)(nil)
func (rw *responseWrap) WriteHeader(statusCode int) {
rw.ResponseWriter.WriteHeader(statusCode)
rw.statusCode = statusCode
}
func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("response does not implement http.Hijacker")
}
return hijacker.Hijack()
}
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
@ -119,7 +98,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
auth = auth[len("Bearer "):]
}
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
prov.log.Infof("Authentication token does not match shared secret")
hlog.FromRequest(r).Debug().Msg("Authentication token does not match shared secret")
jsonResponse(w, http.StatusForbidden, map[string]interface{}{
"error": "Authentication token does not match shared secret",
"errcode": "M_FORBIDDEN",
@ -128,11 +107,12 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
}
userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID))
start := time.Now()
wWrap := &responseWrap{w, 200}
h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user)))
duration := time.Now().Sub(start).Seconds()
prov.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode)
if user != nil {
hlog.FromRequest(r).UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Stringer("user_id", user.MXID)
})
}
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user)))
})
}
@ -157,7 +137,7 @@ func (prov *ProvisioningAPI) DeleteSession(w http.ResponseWriter, r *http.Reques
return
}
user.DeleteConnection()
user.DeleteSession()
user.DeleteSession(r.Context())
jsonResponse(w, http.StatusOK, Response{true, "Session information purged"})
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
}
@ -245,7 +225,7 @@ func (prov *ProvisioningAPI) ListContacts(w http.ResponseWriter, r *http.Request
ErrCode: "no session",
})
} else if contacts, err := user.Session.Contacts.GetAllContacts(); err != nil {
prov.log.Errorfln("Failed to fetch %s's contacts: %v", user.MXID, err)
hlog.FromRequest(r).Err(err).Msg("Failed to fetch all contacts")
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while fetching contact list",
ErrCode: "failed to get contacts",
@ -282,7 +262,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
if r.Method == http.MethodPost {
err := user.ResyncGroups(r.URL.Query().Get("create_portals") == "true")
if err != nil {
prov.log.Errorfln("Failed to resync %s's groups: %v", user.MXID, err)
hlog.FromRequest(r).Err(err).Msg("Failed to resync groups")
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while resyncing groups",
ErrCode: "failed to sync groups",
@ -291,7 +271,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
}
}
if groups, err := user.getCachedGroupList(); err != nil {
prov.log.Errorfln("Failed to fetch %s's groups: %v", user.MXID, err)
hlog.FromRequest(r).Err(err).Msg("Failed to fetch group list")
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while fetching group list",
ErrCode: "failed to get groups",
@ -368,17 +348,17 @@ func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) {
// resolveIdentifier already responded with an error
return
}
portal, puppet, justCreated, err := user.StartPM(jid, "provisioning API PM")
portal, puppet, justCreated, err := user.StartPM(r.Context(), jid, "provisioning API PM")
if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err),
})
}
status := http.StatusOK
statusCode := http.StatusOK
if justCreated {
status = http.StatusCreated
statusCode = http.StatusCreated
}
jsonResponse(w, status, PortalInfo{
jsonResponse(w, statusCode, PortalInfo{
RoomID: portal.MXID,
OtherUser: &OtherUserInfo{
JID: puppet.JID,
@ -449,29 +429,30 @@ func (prov *ProvisioningAPI) OpenGroup(w http.ResponseWriter, r *http.Request) {
ErrCode: "invalid group id",
})
} else if info, err := user.Client.GetGroupInfo(jid); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to get group info by JID")
// TODO return better responses for different errors (like ErrGroupNotFound and ErrNotInGroup)
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to get group info: %v", err),
ErrCode: "error getting group info",
})
} else {
prov.log.Debugln("Importing", jid, "for", user.MXID)
hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Importing group chat for user")
portal := user.GetPortalByJID(info.JID)
status := http.StatusOK
statusCode := http.StatusOK
if len(portal.MXID) == 0 {
err = portal.CreateMatrixRoom(user, info, nil, true, true)
err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err),
})
return
}
status = http.StatusCreated
statusCode = http.StatusCreated
}
jsonResponse(w, status, PortalInfo{
jsonResponse(w, statusCode, PortalInfo{
RoomID: portal.MXID,
GroupInfo: info,
JustCreated: status == http.StatusCreated,
JustCreated: statusCode == http.StatusCreated,
})
}
}
@ -495,6 +476,7 @@ func (prov *ProvisioningAPI) resolveGroupInvite(w http.ResponseWriter, r *http.R
ErrCode: "invalid invite link",
})
} else {
hlog.FromRequest(r).Err(err).Msg("Failed to get group info from link")
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to fetch group info with link: %v", err),
ErrCode: "error getting group info",
@ -530,29 +512,30 @@ func (prov *ProvisioningAPI) JoinGroup(w http.ResponseWriter, r *http.Request) {
}()
inviteCode, _ := mux.Vars(r)["inviteCode"]
if jid, err := user.Client.JoinGroupWithLink(inviteCode); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to join group")
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to join group: %v", err),
ErrCode: "error joining group",
})
} else {
prov.log.Debugln(user.MXID, "successfully joined group", jid)
hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Successfully joined group")
portal := user.GetPortalByJID(jid)
status := http.StatusOK
statusCode := http.StatusOK
if len(portal.MXID) == 0 {
time.Sleep(500 * time.Millisecond) // Wait for incoming group info to create the portal automatically
err = portal.CreateMatrixRoom(user, info, nil, true, true)
err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err),
})
return
}
status = http.StatusCreated
statusCode = http.StatusCreated
}
jsonResponse(w, status, PortalInfo{
jsonResponse(w, statusCode, PortalInfo{
RoomID: portal.MXID,
GroupInfo: info,
JustCreated: status == http.StatusCreated,
JustCreated: statusCode == http.StatusCreated,
})
}
}
@ -616,7 +599,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
} else {
err := user.Client.Logout()
if err != nil {
user.log.Warnln("Error while logging out:", err)
hlog.FromRequest(r).Err(err).Msg("Unknown error while logging out")
if !force {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Unknown error while logging out: %v", err),
@ -632,7 +615,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
user.bridge.Metrics.TrackConnectionState(user.JID, false)
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
user.DeleteSession()
user.DeleteSession(r.Context())
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
}
@ -646,16 +629,17 @@ var upgrader = websocket.Upgrader{
func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID))
log := hlog.FromRequest(r)
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
prov.log.Errorln("Failed to upgrade connection to websocket:", err)
log.Err(err).Msg("Failed to upgrade connection to websocket")
return
}
defer func() {
err := c.Close()
if err != nil {
user.log.Debugln("Error closing websocket:", err)
log.Debug().Err(err).Msg("Error closing websocket")
}
}()
@ -670,23 +654,26 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
}()
ctx, cancel := context.WithCancel(context.Background())
c.SetCloseHandler(func(code int, text string) error {
user.log.Debugfln("Login websocket closed (%d), cancelling login", code)
log.Debug().Int("close_code", code).Msg("Login websocket closed, cancelling login")
cancel()
return nil
})
if userTimezone := r.URL.Query().Get("tz"); userTimezone != "" {
user.log.Debug("Setting timezone to %s", userTimezone)
log.Debug().Str("timezone", userTimezone).Msg("Updating user timezone")
user.Timezone = userTimezone
user.Update()
err = user.Update(r.Context())
if err != nil {
log.Err(err).Msg("Failed to save user after updating timezone")
}
} else {
user.log.Debug("No timezone provided in request")
log.Debug().Msg("No timezone provided in request")
}
qrChan, err := user.Login(ctx)
expiryTime := time.Now().Add(160 * time.Second)
if err != nil {
user.log.Errorln("Failed to log in from provisioning API:", err)
log.Err(err).Msg("Failed to log in via provisioning API")
if errors.Is(err, ErrAlreadyLoggedIn) {
go user.Connect()
_ = c.WriteJSON(Error{
@ -704,7 +691,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
if phoneNum != "" {
pairingCode, err := user.Client.PairPhone(phoneNum, true, whatsmeow.PairClientChrome, "Chrome (Linux)")
if err != nil {
user.zlog.Err(err).Msg("Failed to start phone code login")
log.Err(err).Msg("Failed to start phone code login")
_ = c.WriteJSON(Error{
Error: "Failed to request pairing code",
ErrCode: "code error",
@ -712,6 +699,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
go user.DeleteConnection()
return
} else {
log.Debug().Msg("Started phone number login")
_ = c.WriteJSON(map[string]any{
"pairing_code": pairingCode,
"timeout": int(time.Until(expiryTime).Seconds()),
@ -719,7 +707,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
}
}
user.log.Debugln("Started login via provisioning API")
log.Debug().Msg("Started login via provisioning API")
Analytics.Track(user.MXID, "$login_start")
for {
@ -728,7 +716,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
switch evt.Event {
case whatsmeow.QRChannelSuccess.Event:
jid := user.Client.Store.ID
user.log.Debugln("Successful login as", jid, "via provisioning API")
log.Debug().Stringer("jid", jid).Msg("Successful login via provisioning API")
Analytics.Track(user.MXID, "$login_success")
_ = c.WriteJSON(map[string]interface{}{
"success": true,
@ -737,7 +725,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
"platform": user.Client.Store.Platform,
})
case whatsmeow.QRChannelTimeout.Event:
user.log.Debugln("Login via provisioning API timed out")
log.Debug().Msg("Login via provisioning API timed out")
errCode := "login timed out"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{
@ -745,7 +733,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
ErrCode: errCode,
})
case whatsmeow.QRChannelErrUnexpectedEvent.Event:
user.log.Debugln("Login via provisioning API failed due to unexpected event")
log.Debug().Msg("Login via provisioning API failed due to unexpected event")
errCode := "unexpected event"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{
@ -753,7 +741,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
ErrCode: errCode,
})
case whatsmeow.QRChannelClientOutdated.Event:
user.log.Debugln("Login via provisioning API failed due to outdated client")
log.Debug().Msg("Login via provisioning API failed due to outdated client")
errCode := "bridge outdated"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{

137
puppet.go
View File

@ -17,15 +17,15 @@
package main
import (
"context"
"fmt"
"regexp"
"sync"
"time"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
@ -59,6 +59,7 @@ func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
}
func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
ctx := context.TODO()
jid = jid.ToNonAD()
if jid.Server == types.LegacyUserServer {
jid.Server = types.DefaultUserServer
@ -69,11 +70,19 @@ func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
defer br.puppetsLock.Unlock()
puppet, ok := br.puppets[jid]
if !ok {
dbPuppet := br.DB.Puppet.Get(jid)
dbPuppet, err := br.DB.Puppet.Get(ctx, jid)
if err != nil {
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to get puppet from database")
return nil
}
if dbPuppet == nil {
dbPuppet = br.DB.Puppet.New()
dbPuppet.JID = jid
dbPuppet.Insert()
err = dbPuppet.Insert(ctx)
if err != nil {
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to insert new puppet to database")
return nil
}
}
puppet = br.NewPuppet(dbPuppet)
br.puppets[puppet.JID] = puppet
@ -89,7 +98,10 @@ func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
defer br.puppetsLock.Unlock()
puppet, ok := br.puppetsByCustomMXID[mxid]
if !ok {
dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid)
dbPuppet, err := br.DB.Puppet.GetByCustomMXID(context.TODO(), mxid)
if err != nil {
br.ZLog.Err(err).Stringer("mxid", mxid).Msg("Failed to get puppet by custom mxid from database")
}
if dbPuppet == nil {
return nil
}
@ -137,14 +149,18 @@ func (puppet *Puppet) GetMXID() id.UserID {
}
func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID())
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID(context.TODO()))
}
func (br *WABridge) GetAllPuppets() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll())
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll(context.TODO()))
}
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet, err error) []*Puppet {
if err != nil {
br.ZLog.Err(err).Msg("Error getting puppets from database")
return nil
}
br.puppetsLock.Lock()
defer br.puppetsLock.Unlock()
output := make([]*Puppet, len(dbPuppets))
@ -175,7 +191,7 @@ func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
return &Puppet{
Puppet: dbPuppet,
bridge: br,
log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
zlog: br.ZLog.With().Stringer("puppet_jid", dbPuppet.JID).Logger(),
MXID: br.FormatPuppetMXID(dbPuppet.JID),
}
@ -185,7 +201,7 @@ type Puppet struct {
*database.Puppet
bridge *WABridge
log log.Logger
zlog zerolog.Logger
typingIn id.RoomID
typingAt time.Time
@ -223,47 +239,47 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI {
return puppet.bridge.AS.Intent(puppet.MXID)
}
func (puppet *Puppet) UpdateAvatar(source *User, forcePortalSync bool) bool {
changed := source.updateAvatar(puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.log, puppet.DefaultIntent())
func (puppet *Puppet) UpdateAvatar(ctx context.Context, source *User, forcePortalSync bool) bool {
changed := source.updateAvatar(ctx, puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.DefaultIntent())
if !changed || puppet.Avatar == "unauthorized" {
if forcePortalSync {
go puppet.updatePortalAvatar()
go puppet.updatePortalAvatar(ctx)
}
return changed
}
err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
err := puppet.DefaultIntent().SetAvatarURL(ctx, puppet.AvatarURL)
if err != nil {
puppet.log.Warnln("Failed to set avatar:", err)
zerolog.Ctx(ctx).Err(err).Msg("Failed to set avatar from puppet")
} else {
puppet.AvatarSet = true
}
go puppet.updatePortalAvatar()
go puppet.updatePortalAvatar(ctx)
return true
}
func (puppet *Puppet) UpdateName(contact types.ContactInfo, forcePortalSync bool) bool {
func (puppet *Puppet) UpdateName(ctx context.Context, contact types.ContactInfo, forcePortalSync bool) bool {
newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(puppet.JID, contact)
if (puppet.Displayname != newName || !puppet.NameSet) && quality >= puppet.NameQuality {
oldName := puppet.Displayname
puppet.Displayname = newName
puppet.NameQuality = quality
puppet.NameSet = false
err := puppet.DefaultIntent().SetDisplayName(newName)
err := puppet.DefaultIntent().SetDisplayName(ctx, newName)
if err == nil {
puppet.log.Debugln("Updated name", oldName, "->", newName)
puppet.zlog.Debug().Str("old_name", oldName).Str("new_name", newName).Msg("Updated name")
puppet.NameSet = true
go puppet.updatePortalName()
go puppet.updatePortalName(ctx)
} else {
puppet.log.Warnln("Failed to set display name:", err)
puppet.zlog.Err(err).Msg("Failed to set displayname")
}
return true
} else if forcePortalSync {
go puppet.updatePortalName()
go puppet.updatePortalName(ctx)
}
return false
}
func (puppet *Puppet) UpdateContactInfo() bool {
func (puppet *Puppet) UpdateContactInfo(ctx context.Context) bool {
if !puppet.bridge.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
return false
}
@ -281,9 +297,9 @@ func (puppet *Puppet) UpdateContactInfo() bool {
"com.beeper.bridge.service": "whatsapp",
"com.beeper.bridge.network": "whatsapp",
}
err := puppet.DefaultIntent().BeeperUpdateProfile(contactInfo)
err := puppet.DefaultIntent().BeeperUpdateProfile(ctx, contactInfo)
if err != nil {
puppet.log.Warnln("Failed to store custom contact info in profile:", err)
puppet.zlog.Err(err).Msg("Failed to store custom contact info in profile")
return false
} else {
puppet.ContactInfoSet = true
@ -300,7 +316,7 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) {
}
}
func (puppet *Puppet) updatePortalAvatar() {
func (puppet *Puppet) updatePortalAvatar(ctx context.Context) {
puppet.updatePortalMeta(func(portal *Portal) {
if portal.Avatar == puppet.Avatar && portal.AvatarURL == puppet.AvatarURL && (portal.AvatarSet || !portal.shouldSetDMRoomMetadata()) {
return
@ -308,28 +324,31 @@ func (puppet *Puppet) updatePortalAvatar() {
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
portal.AvatarSet = false
defer portal.Update(nil)
if len(portal.MXID) > 0 && !portal.shouldSetDMRoomMetadata() {
portal.UpdateBridgeInfo()
portal.UpdateBridgeInfo(ctx)
} else if len(portal.MXID) > 0 {
_, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL)
_, err := portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, puppet.AvatarURL)
if err != nil {
portal.log.Warnln("Failed to set avatar:", err)
portal.zlog.Err(err).Msg("Failed to set avatar from puppet")
} else {
portal.AvatarSet = true
portal.UpdateBridgeInfo()
portal.UpdateBridgeInfo(ctx)
}
}
err := portal.Update(ctx)
if err != nil {
portal.zlog.Err(err).Msg("Failed to save portal after updating avatar from puppet")
}
})
}
func (puppet *Puppet) updatePortalName() {
func (puppet *Puppet) updatePortalName(ctx context.Context) {
puppet.updatePortalMeta(func(portal *Portal) {
portal.UpdateName(puppet.Displayname, types.EmptyJID, true)
portal.UpdateName(ctx, puppet.Displayname, types.EmptyJID, true)
})
}
func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName bool, reason string) {
func (puppet *Puppet) SyncContact(ctx context.Context, source *User, onlyIfNoName, shouldHavePushName bool, reason string) {
if puppet == nil {
return
}
@ -337,39 +356,67 @@ func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName
source.EnqueuePuppetResync(puppet)
return
}
log := zerolog.Ctx(ctx).With().
Str("method", "Puppet.SyncContact").
Stringer("puppet_jid", puppet.JID).
Stringer("source_user_jid", source.JID).
Stringer("source_user_mxid", source.MXID).
Logger()
ctx = log.WithContext(ctx)
contact, err := source.Client.Store.Contacts.GetContact(puppet.JID)
if err != nil {
puppet.log.Warnfln("Failed to get contact info through %s in SyncContact: %v (sync reason: %s)", source.MXID, reason)
log.Err(err).
Stringer("source_mxid", source.MXID).
Str("sync_reason", reason).
Msg("Failed to get contact info through user in SyncContact")
} else if !contact.Found {
puppet.log.Warnfln("No contact info found through %s in SyncContact (sync reason: %s)", source.MXID, reason)
log.Warn().
Stringer("source_mxid", source.MXID).
Str("sync_reason", reason).
Msg("No contact info found through user in SyncContact")
}
puppet.Sync(source, &contact, false, false)
puppet.syncInternal(ctx, source, &contact, false, false)
}
func (puppet *Puppet) Sync(source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
func (puppet *Puppet) Sync(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
log := zerolog.Ctx(ctx).With().
Str("method", "Puppet.Sync").
Stringer("puppet_jid", puppet.JID).
Stringer("source_user_jid", source.JID).
Stringer("source_user_mxid", source.MXID).
Logger()
ctx = log.WithContext(ctx)
puppet.syncInternal(ctx, source, contact, forceAvatarSync, forcePortalSync)
}
func (puppet *Puppet) syncInternal(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
log := zerolog.Ctx(ctx)
puppet.syncLock.Lock()
defer puppet.syncLock.Unlock()
err := puppet.DefaultIntent().EnsureRegistered()
err := puppet.DefaultIntent().EnsureRegistered(ctx)
if err != nil {
puppet.log.Errorln("Failed to ensure registered:", err)
log.Err(err).Msg("Failed to ensure registered")
}
puppet.log.Debugfln("Syncing info through %s", source.JID)
log.Debug().Stringer("source_jid", source.JID).Msg("Syncing info through user")
update := false
if contact != nil {
if puppet.JID.User == source.JID.User {
contact.PushName = source.Client.Store.PushName
}
update = puppet.UpdateName(*contact, forcePortalSync) || update
update = puppet.UpdateName(ctx, *contact, forcePortalSync) || update
}
if len(puppet.Avatar) == 0 || forceAvatarSync || puppet.bridge.Config.Bridge.UserAvatarSync {
update = puppet.UpdateAvatar(source, forcePortalSync) || update
update = puppet.UpdateAvatar(ctx, source, forcePortalSync) || update
}
update = puppet.UpdateContactInfo() || update
update = puppet.UpdateContactInfo(ctx) || update
if update || puppet.LastSync.Add(24*time.Hour).Before(time.Now()) {
puppet.LastSync = time.Now()
puppet.Update()
err = puppet.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save puppet after sync")
}
}
}

View File

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -19,7 +19,6 @@ package main
import (
"bytes"
"context"
"encoding/json"
"image"
"net/http"
"net/url"
@ -27,33 +26,26 @@ import (
"strings"
"time"
"github.com/tidwall/gjson"
"github.com/rs/zerolog"
"golang.org/x/net/idna"
"google.golang.org/protobuf/proto"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
)
type BeeperLinkPreview struct {
mautrix.RespPreviewURL
MatchedURL string `json:"matched_url"`
ImageEncryption *event.EncryptedFileInfo `json:"beeper:image:encryption,omitempty"`
}
func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*BeeperLinkPreview {
func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*event.BeeperLinkPreview {
if msg.GetMatchedText() == "" {
return []*BeeperLinkPreview{}
return []*event.BeeperLinkPreview{}
}
output := &BeeperLinkPreview{
output := &event.BeeperLinkPreview{
MatchedURL: msg.GetMatchedText(),
RespPreviewURL: mautrix.RespPreviewURL{
LinkPreview: event.LinkPreview{
CanonicalURL: msg.GetCanonicalUrl(),
Title: msg.GetTitle(),
Description: msg.GetDescription(),
@ -68,7 +60,7 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
var err error
thumbnailData, err = source.Client.DownloadThumbnail(msg)
if err != nil {
portal.log.Warnfln("Failed to download thumbnail for link preview: %v", err)
zerolog.Ctx(ctx).Err(err).Msg("Failed to download thumbnail for link preview")
}
}
if thumbnailData == nil && msg.JpegThumbnail != nil {
@ -93,9 +85,9 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
uploadMime = "application/octet-stream"
output.ImageEncryption = &event.EncryptedFileInfo{EncryptedFile: *crypto}
}
resp, err := intent.UploadBytes(uploadData, uploadMime)
resp, err := intent.UploadBytes(ctx, uploadData, uploadMime)
if err != nil {
portal.log.Warnfln("Failed to reupload thumbnail for link preview: %v", err)
zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload thumbnail for link preview")
} else {
if output.ImageEncryption != nil {
output.ImageEncryption.URL = resp.ContentURI.CUString()
@ -108,36 +100,37 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
output.Type = "video.other"
}
return []*BeeperLinkPreview{output}
return []*event.BeeperLinkPreview{output}
}
var URLRegex = regexp.MustCompile(`https?://[^\s/_*]+(?:/\S*)?`)
func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, evt *event.Event, dest *waProto.ExtendedTextMessage) bool {
var preview *BeeperLinkPreview
func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, content *event.MessageEventContent, dest *waProto.ExtendedTextMessage) bool {
log := zerolog.Ctx(ctx)
var preview *event.BeeperLinkPreview
rawPreview := gjson.GetBytes(evt.Content.VeryRaw, `com\.beeper\.linkpreviews`)
if rawPreview.Exists() && rawPreview.IsArray() {
var previews []BeeperLinkPreview
if err := json.Unmarshal([]byte(rawPreview.Raw), &previews); err != nil || len(previews) == 0 {
if content.BeeperLinkPreviews != nil {
// Note: this check explicitly happens after checking for nil: empty arrays are treated as no previews,
// but omitting the field means the bridge may look for URLs in the message text.
if len(content.BeeperLinkPreviews) == 0 {
return false
}
// WhatsApp only supports a single preview.
preview = &previews[0]
preview = content.BeeperLinkPreviews[0]
} else if portal.bridge.Config.Bridge.URLPreviews {
if matchedURL := URLRegex.FindString(evt.Content.AsMessage().Body); len(matchedURL) == 0 {
if matchedURL := URLRegex.FindString(content.Body); len(matchedURL) == 0 {
return false
} else if parsed, err := url.Parse(matchedURL); err != nil {
return false
} else if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil {
return false
} else if mxPreview, err := portal.MainIntent().GetURLPreview(parsed.String()); err != nil {
portal.log.Warnfln("Failed to fetch preview for %s: %v", matchedURL, err)
} else if mxPreview, err := portal.MainIntent().GetURLPreview(ctx, parsed.String()); err != nil {
log.Err(err).Str("url", matchedURL).Msg("Failed to fetch URL preview")
return false
} else {
preview = &BeeperLinkPreview{
RespPreviewURL: *mxPreview,
MatchedURL: matchedURL,
preview = &event.BeeperLinkPreview{
LinkPreview: *mxPreview,
MatchedURL: matchedURL,
}
}
}
@ -163,22 +156,22 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
imageMXC = preview.ImageEncryption.URL.ParseOrIgnore()
}
if !imageMXC.IsEmpty() {
data, err := portal.MainIntent().DownloadBytesContext(ctx, imageMXC)
data, err := portal.MainIntent().DownloadBytes(ctx, imageMXC)
if err != nil {
portal.log.Errorfln("Failed to download URL preview image %s in %s: %v", preview.ImageURL, evt.ID, err)
log.Err(err).Str("image_url", string(preview.ImageURL)).Msg("Failed to download URL preview image")
return true
}
if preview.ImageEncryption != nil {
err = preview.ImageEncryption.DecryptInPlace(data)
if err != nil {
portal.log.Errorfln("Failed to decrypt URL preview image in %s: %v", evt.ID, err)
log.Err(err).Msg("Failed to decrypt URL preview image")
return true
}
}
dest.MediaKeyTimestamp = proto.Int64(time.Now().Unix())
uploadResp, err := sender.Client.Upload(ctx, data, whatsmeow.MediaLinkThumbnail)
if err != nil {
portal.log.Errorfln("Failed to upload URL preview thumbnail in %s: %v", evt.ID, err)
log.Err(err).Msg("Failed to reupload URL preview thumbnail")
return true
}
dest.ThumbnailSha256 = uploadResp.FileSHA256
@ -188,7 +181,7 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
var width, height int
dest.JpegThumbnail, width, height, err = createThumbnailAndGetSize(data, false)
if err != nil {
portal.log.Warnfln("Failed to create JPEG thumbnail for URL preview in %s: %v", evt.ID, err)
log.Err(err).Msg("Failed to create JPEG thumbnail for URL preview")
}
if preview.ImageHeight > 0 && preview.ImageWidth > 0 {
dest.ThumbnailWidth = proto.Uint32(uint32(preview.ImageWidth))

526
user.go

File diff suppressed because it is too large Load Diff