Compare commits

..

1 commit

Author SHA1 Message Date
Marco Antonio Alvarez d26e8dd953
Merge 38128830cc into f8a22aab06 2024-02-16 22:21:34 +01:00
37 changed files with 2929 additions and 3434 deletions

View file

@ -8,8 +8,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go-version: ["1.21", "1.22"]
name: Lint ${{ matrix.go-version == '1.22' && '(latest)' || '(old)' }}
go-version: ["1.20", "1.21"]
steps:
- uses: actions/checkout@v4

View file

@ -15,6 +15,6 @@ repos:
- id: go-vet-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.3.1
rev: v0.2.2
hooks:
- id: zerolog-ban-msgf

View file

@ -1,9 +1,3 @@
# v0.10.6 (unreleased)
* Bumped minimum Go version to 1.21.
* Added 8-letter code pairing support to provisioning API.
* Added more bugs to fix later.
# v0.10.5 (2023-12-16)
* Added support for sending media to channels.

View file

@ -22,7 +22,7 @@ import (
"fmt"
"net/http"
"github.com/rs/zerolog"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
@ -30,7 +30,7 @@ type AnalyticsClient struct {
url string
key string
userID string
log zerolog.Logger
log log.Logger
client http.Client
}
@ -88,9 +88,9 @@ func (sc *AnalyticsClient) Track(userID id.UserID, event string, properties ...m
props["bridge"] = "whatsapp"
err := sc.trackSync(userID, event, props)
if err != nil {
sc.log.Err(err).Str("event", event).Msg("Error tracking event")
sc.log.Errorfln("Error tracking %s: %v", event, err)
} else {
sc.log.Debug().Str("event", event).Msg("Tracked event")
sc.log.Debugln("Tracked", event)
}
}()
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
// Copyright (C) 2021 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,21 +17,22 @@
package main
import (
"context"
"time"
"github.com/rs/zerolog"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database"
)
type BackfillQueue struct {
BackfillQuery *database.BackfillTaskQuery
BackfillQuery *database.BackfillQuery
reCheckChannels []chan bool
log log.Logger
}
func (bq *BackfillQueue) ReCheck() {
bq.log.Infofln("Sending re-checks to %d channels", len(bq.reCheckChannels))
for _, channel := range bq.reCheckChannels {
go func(c chan bool) {
c <- true
@ -39,19 +40,12 @@ func (bq *BackfillQueue) ReCheck() {
}
}
func (bq *BackfillQueue) GetNextBackfill(ctx context.Context, userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.BackfillTask {
func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.Backfill {
for {
if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(ctx, userID, waitForBackfillTypes) {
if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(userID, waitForBackfillTypes) {
// check for immediate when dealing with deferred
if backfill, err := bq.BackfillQuery.GetNext(ctx, userID, backfillTypes); err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get next backfill task")
} else if backfill != nil {
err = backfill.MarkDispatched(ctx)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Int("queue_id", backfill.QueueID).
Msg("Failed to mark backfill task as dispatched")
}
if backfill := bq.BackfillQuery.GetNext(userID, backfillTypes); backfill != nil {
backfill.MarkDispatched()
return backfill
}
}
@ -64,73 +58,38 @@ func (bq *BackfillQueue) GetNextBackfill(ctx context.Context, userID id.UserID,
}
func (user *User) HandleBackfillRequestsLoop(backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType) {
log := user.zlog.With().
Str("action", "backfill request loop").
Any("types", backfillTypes).
Logger()
ctx := log.WithContext(context.TODO())
reCheckChannel := make(chan bool)
user.BackfillQueue.reCheckChannels = append(user.BackfillQueue.reCheckChannels, reCheckChannel)
for {
req := user.BackfillQueue.GetNextBackfill(ctx, user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel)
log.Info().Any("backfill_request", req).Msg("Handling backfill request")
log := log.With().
Int("queue_id", req.QueueID).
Stringer("portal_jid", req.Portal.JID).
Logger()
ctx := log.WithContext(ctx)
req := user.BackfillQueue.GetNextBackfill(user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel)
user.log.Infofln("Handling backfill request %s", req)
conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, req.Portal)
if err != nil {
log.Err(err).Msg("Failed to get conversation data for backfill request")
continue
} else if conv == nil {
log.Debug().Msg("Couldn't find conversation data for backfill request")
err = req.MarkDone(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill request as done after data was not found")
}
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, *req.Portal)
if conv == nil {
user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String())
req.MarkDone()
continue
}
portal := user.GetPortalByJID(conv.PortalKey.JID)
// Update the client store with basic chat settings.
if conv.MuteEndTime.After(time.Now()) {
err = user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
if err != nil {
log.Err(err).Msg("Failed to save muted until time from conversation data")
}
user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
}
if conv.Archived {
err = user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
if err != nil {
log.Err(err).Msg("Failed to save archived state from conversation data")
}
user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
}
if conv.Pinned > 0 {
err = user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
if err != nil {
log.Err(err).Msg("Failed to save pinned state from conversation data")
}
user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
}
if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
log.Debug().
Uint32("old_time", portal.ExpirationTime).
Uint32("new_time", *conv.EphemeralExpiration).
Msg("Updating portal ephemeral expiration time")
portal.ExpirationTime = *conv.EphemeralExpiration
err = portal.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal after updating expiration time")
}
portal.Update(nil)
}
user.backfillInChunks(ctx, req, conv, portal)
err = req.MarkDone(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill request as done after backfilling")
}
user.backfillInChunks(req, conv, portal)
req.MarkDone()
}
}

View file

@ -93,7 +93,7 @@ func (prov *ProvisioningAPI) BridgeStatePing(w http.ResponseWriter, r *http.Requ
remote = remote.Fill(user)
resp.RemoteStates[remote.RemoteID] = remote
}
user.zlog.Debug().Any("response_data", &resp).Msg("Responding bridge state in bridge status endpoint")
user.log.Debugfln("Responding bridge state in bridge status endpoint: %+v", resp)
jsonResponse(w, http.StatusOK, &resp)
if len(resp.RemoteStates) > 0 {
user.BridgeState.SetPrev(remote)

View file

@ -29,7 +29,6 @@ import (
"strings"
"time"
"github.com/rs/zerolog"
"github.com/skip2/go-qrcode"
"github.com/tidwall/gjson"
@ -38,6 +37,7 @@ import (
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/event"
@ -117,10 +117,7 @@ func fnSetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ce.User.MXID
err := ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after setting relay user")
}
ce.Portal.Update(nil)
ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account")
}
}
@ -142,10 +139,7 @@ func fnUnsetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ""
err := ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after clearing relay user")
}
ce.Portal.Update(nil)
ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
}
}
@ -252,7 +246,7 @@ func fnJoin(ce *WrappedCommandEvent) {
ce.Reply("Failed to join group: %v", err)
return
}
ce.ZLog.Debug().Stringer("group_jid", jid).Msg("User successfully joined WhatsApp group with link")
ce.Log.Debugln("%s successfully joined group %s", ce.User.MXID, jid)
ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
} else if strings.HasPrefix(ce.Args[0], whatsmeow.NewsletterLinkPrefix) {
info, err := ce.User.Client.GetNewsletterInfoWithInvite(ce.Args[0])
@ -265,14 +259,14 @@ func fnJoin(ce *WrappedCommandEvent) {
ce.Reply("Failed to follow channel: %v", err)
return
}
ce.ZLog.Debug().Stringer("channel_jid", info.ID).Msg("User successfully followed WhatsApp channel with link")
ce.Log.Debugln("%s successfully followed channel %s", ce.User.MXID, info.ID)
ce.Reply("Successfully followed channel `%s`, the portal should be created momentarily", info.ID)
} else {
ce.Reply("That doesn't look like a WhatsApp invite link")
}
}
func tryDecryptEvent(ce *WrappedCommandEvent, evt *event.Event) (json.RawMessage, error) {
func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) {
var data json.RawMessage
if evt.Type != event.EventEncrypted {
data = evt.Content.VeryRaw
@ -281,7 +275,7 @@ func tryDecryptEvent(ce *WrappedCommandEvent, evt *event.Event) (json.RawMessage
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, err
}
decrypted, err := ce.Bridge.Crypto.Decrypt(ce.Ctx, evt)
decrypted, err := crypto.Decrypt(evt)
if err != nil {
return nil, err
}
@ -317,11 +311,11 @@ var cmdAccept = &commands.FullHandler{
func fnAccept(ce *WrappedCommandEvent) {
if len(ce.ReplyTo) == 0 {
ce.Reply("You must reply to a group invite message when using this command.")
} else if evt, err := ce.Portal.MainIntent().GetEvent(ce.Ctx, ce.RoomID, ce.ReplyTo); err != nil {
ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to get reply target event to handle !wa accept command")
} else if evt, err := ce.Portal.MainIntent().GetEvent(ce.RoomID, ce.ReplyTo); err != nil {
ce.Log.Errorln("Failed to get event %s to handle !wa accept command: %v", ce.ReplyTo, err)
ce.Reply("Failed to get reply event")
} else if rawContent, err := tryDecryptEvent(ce, evt); err != nil {
ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to decrypt reply target event to handle !wa accept command")
} else if rawContent, err := tryDecryptEvent(ce.Bridge.Crypto, evt); err != nil {
ce.Log.Errorln("Failed to decrypt event %s to handle !wa accept command: %v", ce.ReplyTo, err)
ce.Reply("Failed to decrypt reply event")
} else if meta, err := parseInviteMeta(rawContent); err != nil || meta == nil {
ce.Reply("That doesn't look like a group invite message.")
@ -350,16 +344,16 @@ func fnCreate(ce *WrappedCommandEvent) {
return
}
members, err := ce.Bot.JoinedMembers(ce.Ctx, ce.RoomID)
members, err := ce.Bot.JoinedMembers(ce.RoomID)
if err != nil {
ce.Reply("Failed to get room members: %v", err)
return
}
var roomNameEvent event.RoomNameEventContent
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateRoomName, "", &roomNameEvent)
err = ce.Bot.StateEvent(ce.RoomID, event.StateRoomName, "", &roomNameEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.ZLog.Err(err).Msg("Failed to get room name to create group")
ce.Log.Errorln("Failed to get room name to create group:", err)
ce.Reply("Failed to get room name")
return
} else if len(roomNameEvent.Name) == 0 {
@ -368,17 +362,15 @@ func fnCreate(ce *WrappedCommandEvent) {
}
var encryptionEvent event.EncryptionEventContent
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateEncryption, "", &encryptionEvent)
err = ce.Bot.StateEvent(ce.RoomID, event.StateEncryption, "", &encryptionEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.ZLog.Err(err).Msg("Failed to get room encryption status to create group")
ce.Reply("Failed to get room encryption status")
return
}
var createEvent event.CreateEventContent
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateCreate, "", &createEvent)
err = ce.Bot.StateEvent(ce.RoomID, event.StateCreate, "", &createEvent)
if err != nil && !errors.Is(err, mautrix.MNotFound) {
ce.ZLog.Err(err).Msg("Failed to get room create event to create group")
ce.Reply("Failed to get room create event")
return
}
@ -403,11 +395,7 @@ func fnCreate(ce *WrappedCommandEvent) {
// TODO check m.space.parent to create rooms directly in communities
messageID := ce.User.Client.GenerateMessageID()
ce.ZLog.Info().
Str("room_name", roomNameEvent.Name).
Any("participants", participants).
Str("create_key", messageID).
Msg("Creating WhatsApp group for Matrix room")
ce.Log.Infofln("Creating group for %s with name %s and participants %+v (create key: %s)", ce.RoomID, roomNameEvent.Name, participants, messageID)
ce.User.createKeyDedup = messageID
resp, err := ce.User.Client.CreateGroup(whatsmeow.ReqCreateGroup{
CreateKey: messageID,
@ -421,25 +409,21 @@ func fnCreate(ce *WrappedCommandEvent) {
ce.Reply("Failed to create group: %v", err)
return
}
ce.ZLog.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("group_jid", resp.JID.String())
})
portal := ce.User.GetPortalByJID(resp.JID)
portal.roomCreateLock.Lock()
defer portal.roomCreateLock.Unlock()
if len(portal.MXID) != 0 {
ce.ZLog.Warn().Msg("Detected race condition in room creation")
portal.log.Warnln("Detected race condition in room creation")
// TODO race condition, clean up the old room
}
portal.MXID = ce.RoomID
portal.updateLogger()
portal.Name = roomNameEvent.Name
portal.IsParent = resp.IsParent
portal.Encrypted = encryptionEvent.Algorithm == id.AlgorithmMegolmV1
if !portal.Encrypted && ce.Bridge.Config.Bridge.Encryption.Default {
_, err = portal.MainIntent().SendStateEvent(ce.Ctx, portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent())
_, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent())
if err != nil {
ce.ZLog.Err(err).Msg("Failed to enable encryption in room")
portal.log.Warnln("Failed to enable encryption in room:", err)
if errors.Is(err, mautrix.MForbidden) {
ce.Reply("I don't seem to have permission to enable encryption in this room.")
} else {
@ -449,11 +433,8 @@ func fnCreate(ce *WrappedCommandEvent) {
portal.Encrypted = true
}
err = portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after creating group")
}
portal.UpdateBridgeInfo(ce.Ctx)
portal.Update(nil)
portal.UpdateBridgeInfo()
ce.User.createKeyDedup = ""
ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID)
@ -531,7 +512,7 @@ func fnLogin(ce *WrappedCommandEvent) {
}
}
if qrEventID != "" {
_, _ = ce.Bot.RedactEvent(ce.Ctx, ce.RoomID, qrEventID)
_, _ = ce.Bot.RedactEvent(ce.RoomID, qrEventID)
}
}
@ -548,9 +529,9 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
if len(prevEvent) != 0 {
content.SetEdit(prevEvent)
}
resp, err := ce.Bot.SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, &content)
resp, err := ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &content)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to send edited QR code to user")
user.log.Errorln("Failed to send edited QR code to user:", err)
} else if len(prevEvent) == 0 {
prevEvent = resp.EventID
}
@ -560,16 +541,16 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
func (user *User) uploadQR(ce *WrappedCommandEvent, code string) (id.ContentURI, bool) {
qrCode, err := qrcode.Encode(code, qrcode.Low, 256)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to encode QR code")
user.log.Errorln("Failed to encode QR code:", err)
ce.Reply("Failed to encode QR code: %v", err)
return id.ContentURI{}, false
}
bot := user.bridge.AS.BotClient()
resp, err := bot.UploadBytes(ce.Ctx, qrCode, "image/png")
resp, err := bot.UploadBytes(qrCode, "image/png")
if err != nil {
ce.ZLog.Err(err).Msg("Failed to upload QR code")
user.log.Errorln("Failed to upload QR code:", err)
ce.Reply("Failed to upload QR code: %v", err)
return id.ContentURI{}, false
}
@ -597,14 +578,14 @@ func fnLogout(ce *WrappedCommandEvent) {
puppet.ClearCustomMXID()
err := ce.User.Client.Logout()
if err != nil {
ce.ZLog.Err(err).Msg("Unknown error while logging out")
ce.User.log.Warnln("Error while logging out:", err)
ce.Reply("Unknown error while logging out: %v", err)
return
}
ce.User.Session = nil
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
ce.User.DeleteConnection()
ce.User.DeleteSession(ce.Ctx)
ce.User.DeleteSession()
ce.Reply("Logged out successfully.")
}
@ -639,13 +620,10 @@ func fnTogglePresence(ce *WrappedCommandEvent) {
if ce.User.IsLoggedIn() {
err := ce.User.Client.SendPresence(newPresence)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to send presence to WhatsApp")
ce.User.log.Warnln("Failed to set presence:", err)
}
}
err := customPuppet.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save puppet after toggling presence")
}
customPuppet.Update()
}
var cmdDeleteSession = &commands.FullHandler{
@ -664,7 +642,7 @@ func fnDeleteSession(ce *WrappedCommandEvent) {
}
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
ce.User.DeleteConnection()
ce.User.DeleteSession(ce.Ctx)
ce.User.DeleteSession()
ce.Reply("Session information purged")
}
@ -738,19 +716,19 @@ func fnPing(ce *WrappedCommandEvent) {
}
}
func canDeletePortal(ce *WrappedCommandEvent, portal *Portal) bool {
func canDeletePortal(portal *Portal, userID id.UserID) bool {
if len(portal.MXID) == 0 {
return false
}
members, err := portal.MainIntent().JoinedMembers(ce.Ctx, portal.MXID)
members, err := portal.MainIntent().JoinedMembers(portal.MXID)
if err != nil {
ce.ZLog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to get joined members to check if portal can be deleted by user")
portal.log.Errorfln("Failed to get joined members to check if portal can be deleted by %s: %v", userID, err)
return false
}
for otherUser := range members.Joined {
_, isPuppet := portal.bridge.ParsePuppetMXID(otherUser)
if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == ce.User.MXID {
if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == userID {
continue
}
user := portal.bridge.GetUserByMXID(otherUser)
@ -772,14 +750,14 @@ var cmdDeletePortal = &commands.FullHandler{
}
func fnDeletePortal(ce *WrappedCommandEvent) {
if !ce.User.Admin && !canDeletePortal(ce, ce.Portal) {
if !ce.User.Admin && !canDeletePortal(ce.Portal, ce.User.MXID) {
ce.Reply("Only bridge admins can delete portals with other Matrix users")
return
}
ce.ZLog.Info().Msg("User requested deletion of current portal")
ce.Portal.Delete(ce.Ctx)
ce.Portal.Cleanup(ce.Ctx, false)
ce.Portal.log.Infoln(ce.User.MXID, "requested deletion of portal.")
ce.Portal.Delete()
ce.Portal.Cleanup(false)
}
var cmdDeleteAllPortals = &commands.FullHandler{
@ -800,7 +778,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
} else {
portalsToDelete = portals[:0]
for _, portal := range portals {
if canDeletePortal(ce, portal) {
if canDeletePortal(portal, ce.User.MXID) {
portalsToDelete = append(portalsToDelete, portal)
}
}
@ -812,7 +790,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
leave := func(portal *Portal) {
if len(portal.MXID) > 0 {
_, _ = portal.MainIntent().KickUser(ce.Ctx, portal.MXID, &mautrix.ReqKickUser{
_, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{
Reason: "Deleting portal",
UserID: ce.User.MXID,
})
@ -823,21 +801,21 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
intent := customPuppet.CustomIntent()
leave = func(portal *Portal) {
if len(portal.MXID) > 0 {
_, _ = intent.LeaveRoom(ce.Ctx, portal.MXID)
_, _ = intent.ForgetRoom(ce.Ctx, portal.MXID)
_, _ = intent.LeaveRoom(portal.MXID)
_, _ = intent.ForgetRoom(portal.MXID)
}
}
}
ce.Reply("Found %d portals, deleting...", len(portalsToDelete))
for _, portal := range portalsToDelete {
portal.Delete(ce.Ctx)
portal.Delete()
leave(portal)
}
ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.")
go func() {
for _, portal := range portalsToDelete {
portal.Cleanup(ce.Ctx, false)
portal.Cleanup(false)
}
ce.Reply("Finished background cleanup of deleted portal rooms.")
}()
@ -904,7 +882,7 @@ func fnList(ce *WrappedCommandEvent) {
}
var err error
page := 1
maxPerPage := 100
max := 100
if len(ce.Args) > 1 {
page, err = strconv.Atoi(ce.Args[1])
if err != nil || page <= 0 {
@ -913,11 +891,11 @@ func fnList(ce *WrappedCommandEvent) {
}
}
if len(ce.Args) > 2 {
maxPerPage, err = strconv.Atoi(ce.Args[2])
if err != nil || maxPerPage <= 0 {
max, err = strconv.Atoi(ce.Args[2])
if err != nil || max <= 0 {
ce.Reply("\"%s\" isn't a valid number of items per page", ce.Args[2])
return
} else if maxPerPage > 400 {
} else if max > 400 {
ce.Reply("Warning: a high number of items per page may fail to send a reply")
}
}
@ -946,8 +924,8 @@ func fnList(ce *WrappedCommandEvent) {
ce.Reply("No %s found", strings.ToLower(typeName))
return
}
pages := int(math.Ceil(float64(len(result)) / float64(maxPerPage)))
if (page-1)*maxPerPage >= len(result) {
pages := int(math.Ceil(float64(len(result)) / float64(max)))
if (page-1)*max >= len(result) {
if pages == 1 {
ce.Reply("There is only 1 page of %s", strings.ToLower(typeName))
} else {
@ -955,11 +933,11 @@ func fnList(ce *WrappedCommandEvent) {
}
return
}
lastIndex := page * maxPerPage
lastIndex := page * max
if lastIndex > len(result) {
lastIndex = len(result)
}
result = result[(page-1)*maxPerPage : lastIndex]
result = result[(page-1)*max : lastIndex]
ce.Reply("### %s (page %d of %d)\n\n%s", typeName, page, pages, strings.Join(result, "\n"))
}
@ -1058,13 +1036,13 @@ func fnOpen(ce *WrappedCommandEvent) {
}
jid = newsletterMetadata.ID
}
ce.ZLog.Debug().Stringer("chat_jid", jid).Msg("Importing chat for user")
ce.Log.Debugln("Importing", jid, "for", ce.User.MXID)
portal := ce.User.GetPortalByJID(jid)
if len(portal.MXID) > 0 {
portal.UpdateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata)
portal.UpdateMatrixRoom(ce.User, groupInfo, newsletterMetadata)
ce.Reply("Portal room synced.")
} else {
err = portal.CreateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata, true, true)
err = portal.CreateMatrixRoom(ce.User, groupInfo, newsletterMetadata, true, true)
if err != nil {
ce.Reply("Failed to create room: %v", err)
} else {
@ -1107,7 +1085,7 @@ func fnPM(ce *WrappedCommandEvent) {
return
}
portal, puppet, justCreated, err := user.StartPM(ce.Ctx, targetUser.JID, "manual PM command")
portal, puppet, justCreated, err := user.StartPM(targetUser.JID, "manual PM command")
if err != nil {
ce.Reply("Failed to create portal room: %v", err)
} else if !justCreated {
@ -1176,16 +1154,11 @@ func fnSync(ce *WrappedCommandEvent) {
ce.Reply("Personal filtering spaces are not enabled on this instance of the bridge")
return
}
keys, err := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.Ctx, ce.User.JID)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to get list of private chats not in space")
ce.Reply("Failed to get list of private chats not in space")
return
}
keys := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.JID)
count := 0
for _, key := range keys {
portal := ce.Bridge.GetPortalByJID(key)
portal.addToPersonalSpace(ce.Ctx, ce.User)
portal.addToPersonalSpace(ce.User)
count++
}
plural := "s"
@ -1235,9 +1208,6 @@ func fnDisappearingTimer(ce *WrappedCommandEvent) {
ce.Portal.ExpirationTime = prevExpirationTime
return
}
err = ce.Portal.Update(ce.Ctx)
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save portal after setting disappearing timer")
}
ce.Portal.Update(nil)
ce.React("✅")
}

View file

@ -17,9 +17,6 @@
package main
import (
"context"
"fmt"
"maunium.net/go/mautrix/id"
)
@ -27,11 +24,8 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error
puppet.CustomMXID = mxid
puppet.AccessToken = accessToken
puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence
err := puppet.Update(context.TODO())
if err != nil {
return fmt.Errorf("failed to save access token: %w", err)
}
err = puppet.StartCustomMXID(false)
puppet.Update()
err := puppet.StartCustomMXID(false)
if err != nil {
return err
}
@ -51,15 +45,12 @@ func (puppet *Puppet) ClearCustomMXID() {
puppet.customIntent = nil
puppet.customUser = nil
if save {
err := puppet.Update(context.TODO())
if err != nil {
puppet.zlog.Err(err).Msg("Failed to clear custom MXID")
}
puppet.Update()
}
}
func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(context.TODO(), puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
if err != nil {
puppet.ClearCustomMXID()
return err
@ -69,11 +60,11 @@ func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
puppet.bridge.puppetsLock.Unlock()
if puppet.AccessToken != newAccessToken {
puppet.AccessToken = newAccessToken
err = puppet.Update(context.TODO())
puppet.Update()
}
puppet.customIntent = newIntent
puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID)
return err
return nil
}
func (user *User) tryAutomaticDoublePuppeting() {

340
database/backfill.go Normal file
View file

@ -0,0 +1,340 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type BackfillType int
const (
BackfillImmediate BackfillType = 0
BackfillForward BackfillType = 100
BackfillDeferred BackfillType = 200
)
func (bt BackfillType) String() string {
switch bt {
case BackfillImmediate:
return "IMMEDIATE"
case BackfillForward:
return "FORWARD"
case BackfillDeferred:
return "DEFERRED"
}
return "UNKNOWN"
}
type BackfillQuery struct {
db *Database
log log.Logger
backfillQueryLock sync.Mutex
}
func (bq *BackfillQuery) New() *Backfill {
return &Backfill{
db: bq.db,
log: bq.log,
Portal: &PortalKey{},
}
}
func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill {
return &Backfill{
db: bq.db,
log: bq.log,
UserID: userID,
BackfillType: backfillType,
Priority: priority,
Portal: portal,
TimeStart: timeStart,
MaxBatchEvents: maxBatchEvents,
MaxTotalEvents: maxTotalEvents,
BatchDelay: batchDelay,
}
}
const (
getNextBackfillQuery = `
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (
dispatch_time IS NULL
OR (
dispatch_time < $2
AND completed_at IS NULL
)
)
ORDER BY type, priority, queue_id
LIMIT 1
`
getUnstartedOrInFlightQuery = `
SELECT 1
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (dispatch_time IS NULL OR completed_at IS NULL)
LIMIT 1
`
)
// GetNext returns the next backfill to perform
func (bq *BackfillQuery) GetNext(userID id.UserID, backfillTypes []BackfillType) (backfill *Backfill) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
var types []string
for _, backfillType := range backfillTypes {
types = append(types, strconv.Itoa(int(backfillType)))
}
rows, err := bq.db.Query(fmt.Sprintf(getNextBackfillQuery, strings.Join(types, ",")), userID, time.Now().Add(-15*time.Minute))
if err != nil || rows == nil {
bq.log.Errorfln("Failed to query next backfill queue job: %v", err)
return
}
defer rows.Close()
if rows.Next() {
backfill = bq.New().Scan(rows)
}
return
}
func (bq *BackfillQuery) HasUnstartedOrInFlightOfType(userID id.UserID, backfillTypes []BackfillType) bool {
if len(backfillTypes) == 0 {
return false
}
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
types := []string{}
for _, backfillType := range backfillTypes {
types = append(types, strconv.Itoa(int(backfillType)))
}
rows, err := bq.db.Query(fmt.Sprintf(getUnstartedOrInFlightQuery, strings.Join(types, ",")), userID)
if err != nil || rows == nil {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
bq.log.Warnfln("Failed to query backfill queue jobs: %v", err)
}
// No rows means that there are no unstarted or in flight backfill
// requests.
return false
}
defer rows.Close()
return rows.Next()
}
func (bq *BackfillQuery) DeleteAll(userID id.UserID) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
_, err := bq.db.Exec("DELETE FROM backfill_queue WHERE user_mxid=$1", userID)
if err != nil {
bq.log.Warnfln("Failed to delete backfill queue items for %s: %v", userID, err)
}
}
func (bq *BackfillQuery) DeleteAllForPortal(userID id.UserID, portalKey PortalKey) {
bq.backfillQueryLock.Lock()
defer bq.backfillQueryLock.Unlock()
_, err := bq.db.Exec(`
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`, userID, portalKey.JID, portalKey.Receiver)
if err != nil {
bq.log.Warnfln("Failed to delete backfill queue items for %s/%s: %v", userID, portalKey.JID, err)
}
}
type Backfill struct {
db *Database
log log.Logger
// Fields
QueueID int
UserID id.UserID
BackfillType BackfillType
Priority int
Portal *PortalKey
TimeStart *time.Time
MaxBatchEvents int
MaxTotalEvents int
BatchDelay int
DispatchTime *time.Time
CompletedAt *time.Time
}
func (b *Backfill) String() string {
return fmt.Sprintf("Backfill{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
)
}
func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
var maxTotalEvents, batchDelay sql.NullInt32
err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &maxTotalEvents, &batchDelay)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
b.log.Errorln("Database scan failed:", err)
}
return nil
}
b.MaxTotalEvents = int(maxTotalEvents.Int32)
b.BatchDelay = int(batchDelay.Int32)
return b
}
func (b *Backfill) Insert() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
rows, err := b.db.Query(`
INSERT INTO backfill_queue
(user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING queue_id
`, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt)
defer rows.Close()
if err != nil || !rows.Next() {
b.log.Warnfln("Failed to insert %v/%s with priority %d: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
return
}
err = rows.Scan(&b.QueueID)
if err != nil {
b.log.Warnfln("Failed to insert %s/%s with priority %s: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
}
}
func (b *Backfill) MarkDispatched() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
b.log.Errorfln("Cannot mark backfill as dispatched without queue_id. Maybe it wasn't actually inserted in the database?")
return
}
_, err := b.db.Exec("UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
if err != nil {
b.log.Warnfln("Failed to mark %s/%s as dispatched: %v", b.BackfillType, b.Priority, err)
}
}
func (b *Backfill) MarkDone() {
b.db.Backfill.backfillQueryLock.Lock()
defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
b.log.Errorfln("Cannot mark backfill done without queue_id. Maybe it wasn't actually inserted in the database?")
return
}
_, err := b.db.Exec("UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
if err != nil {
b.log.Warnfln("Failed to mark %s/%s as complete: %v", b.BackfillType, b.Priority, err)
}
}
func (bq *BackfillQuery) NewBackfillState(userID id.UserID, portalKey *PortalKey) *BackfillState {
return &BackfillState{
db: bq.db,
log: bq.log,
UserID: userID,
Portal: portalKey,
}
}
const (
getBackfillState = `
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
FROM backfill_state
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
)
type BackfillState struct {
db *Database
log log.Logger
// Fields
UserID id.UserID
Portal *PortalKey
ProcessingBatch bool
BackfillComplete bool
FirstExpectedTimestamp uint64
}
func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
b.log.Errorln("Database scan failed:", err)
}
return nil
}
return b
}
func (b *BackfillState) Upsert() {
_, err := b.db.Exec(`
INSERT INTO backfill_state
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
DO UPDATE SET
processing_batch=EXCLUDED.processing_batch,
backfill_complete=EXCLUDED.backfill_complete,
first_expected_ts=EXCLUDED.first_expected_ts`,
b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp)
if err != nil {
b.log.Warnfln("Failed to insert backfill state for %s: %v", b.Portal.JID, err)
}
}
func (b *BackfillState) SetProcessingBatch(processing bool) {
b.ProcessingBatch = processing
b.Upsert()
}
func (bq *BackfillQuery) GetBackfillState(userID id.UserID, portalKey *PortalKey) (backfillState *BackfillState) {
rows, err := bq.db.Query(getBackfillState, userID, portalKey.JID, portalKey.Receiver)
if err != nil || rows == nil {
bq.log.Error(err)
return
}
defer rows.Close()
if rows.Next() {
backfillState = bq.NewBackfillState(userID, portalKey).Scan(rows)
}
return
}

View file

@ -1,253 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
type BackfillType int
const (
BackfillImmediate BackfillType = 0
BackfillForward BackfillType = 100
BackfillDeferred BackfillType = 200
)
func (bt BackfillType) String() string {
switch bt {
case BackfillImmediate:
return "IMMEDIATE"
case BackfillForward:
return "FORWARD"
case BackfillDeferred:
return "DEFERRED"
}
return "UNKNOWN"
}
type BackfillTaskQuery struct {
*dbutil.QueryHelper[*BackfillTask]
//backfillQueryLock sync.Mutex
}
func newBackfillTask(qh *dbutil.QueryHelper[*BackfillTask]) *BackfillTask {
return &BackfillTask{qh: qh}
}
func (bq *BackfillTaskQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *BackfillTask {
return &BackfillTask{
qh: bq.QueryHelper,
UserID: userID,
BackfillType: backfillType,
Priority: priority,
Portal: portal,
TimeStart: timeStart,
MaxBatchEvents: maxBatchEvents,
MaxTotalEvents: maxTotalEvents,
BatchDelay: batchDelay,
}
}
const (
getNextBackfillTaskQueryTemplate = `
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (
dispatch_time IS NULL
OR (
dispatch_time < $2
AND completed_at IS NULL
)
)
ORDER BY type, priority, queue_id
LIMIT 1
`
getUnstartedOrInFlightBackfillTaskQueryTemplate = `
SELECT 1
FROM backfill_queue
WHERE user_mxid=$1
AND type IN (%s)
AND (dispatch_time IS NULL OR completed_at IS NULL)
LIMIT 1
`
deleteBackfillQueueForUserQuery = "DELETE FROM backfill_queue WHERE user_mxid=$1"
deleteBackfillQueueForPortalQuery = `
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
insertBackfillTaskQuery = `
INSERT INTO backfill_queue (
user_mxid, type, priority, portal_jid, portal_receiver, time_start,
max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING queue_id
`
markBackfillTaskDispatchedQuery = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"
markBackfillTaskDoneQuery = "UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2"
)
func typesToString(backfillTypes []BackfillType) string {
types := make([]string, len(backfillTypes))
for i, backfillType := range backfillTypes {
types[i] = strconv.Itoa(int(backfillType))
}
return strings.Join(types, ",")
}
// GetNext returns the next backfill to perform
func (bq *BackfillTaskQuery) GetNext(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (*BackfillTask, error) {
if len(backfillTypes) == 0 {
return nil, nil
}
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
query := fmt.Sprintf(getNextBackfillTaskQueryTemplate, typesToString(backfillTypes))
return bq.QueryOne(ctx, query, userID, time.Now().Add(-15*time.Minute))
}
func (bq *BackfillTaskQuery) HasUnstartedOrInFlightOfType(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (has bool) {
if len(backfillTypes) == 0 {
return false
}
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
query := fmt.Sprintf(getUnstartedOrInFlightBackfillTaskQueryTemplate, typesToString(backfillTypes))
err := bq.GetDB().QueryRow(ctx, query, userID).Scan(&has)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
zerolog.Ctx(ctx).Err(err).Msg("Failed to check if backfill queue has jobs")
}
return
}
func (bq *BackfillTaskQuery) DeleteAll(ctx context.Context, userID id.UserID) error {
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
return bq.Exec(ctx, deleteBackfillQueueForUserQuery, userID)
}
func (bq *BackfillTaskQuery) DeleteAllForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
//bq.backfillQueryLock.Lock()
//defer bq.backfillQueryLock.Unlock()
return bq.Exec(ctx, deleteBackfillQueueForPortalQuery, userID, portalKey.JID, portalKey.Receiver)
}
type BackfillTask struct {
qh *dbutil.QueryHelper[*BackfillTask]
QueueID int
UserID id.UserID
BackfillType BackfillType
Priority int
Portal PortalKey
TimeStart *time.Time
MaxBatchEvents int
MaxTotalEvents int
BatchDelay int
DispatchTime *time.Time
CompletedAt *time.Time
}
func (b *BackfillTask) MarshalZerologObject(evt *zerolog.Event) {
evt.Int("queue_id", b.QueueID).
Stringer("user_id", b.UserID).
Stringer("backfill_type", b.BackfillType).
Int("priority", b.Priority).
Stringer("portal_jid", b.Portal.JID).
Any("time_start", b.TimeStart).
Int("max_batch_events", b.MaxBatchEvents).
Int("max_total_events", b.MaxTotalEvents).
Int("batch_delay", b.BatchDelay).
Any("dispatch_time", b.DispatchTime).
Any("completed_at", b.CompletedAt)
}
func (b *BackfillTask) String() string {
return fmt.Sprintf(
"BackfillTask{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
)
}
func (b *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) {
var maxTotalEvents, batchDelay sql.NullInt32
err := row.Scan(
&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart,
&b.MaxBatchEvents, &maxTotalEvents, &batchDelay,
)
if err != nil {
return nil, err
}
b.MaxTotalEvents = int(maxTotalEvents.Int32)
b.BatchDelay = int(batchDelay.Int32)
return b, nil
}
func (b *BackfillTask) sqlVariables() []any {
return []any{
b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart,
b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt,
}
}
func (b *BackfillTask) Insert(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
return b.qh.GetDB().QueryRow(ctx, insertBackfillTaskQuery, b.sqlVariables()...).Scan(&b.QueueID)
}
func (b *BackfillTask) MarkDispatched(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
}
return b.qh.Exec(ctx, markBackfillTaskDispatchedQuery, time.Now(), b.QueueID)
}
func (b *BackfillTask) MarkDone(ctx context.Context) error {
//b.db.Backfill.backfillQueryLock.Lock()
//defer b.db.Backfill.backfillQueryLock.Unlock()
if b.QueueID == 0 {
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
}
return b.qh.Exec(ctx, markBackfillTaskDoneQuery, time.Now(), b.QueueID)
}

View file

@ -1,94 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
type BackfillStateQuery struct {
*dbutil.QueryHelper[*BackfillState]
}
func newBackfillState(qh *dbutil.QueryHelper[*BackfillState]) *BackfillState {
return &BackfillState{qh: qh}
}
func (bq *BackfillStateQuery) NewBackfillState(userID id.UserID, portalKey PortalKey) *BackfillState {
return &BackfillState{
qh: bq.QueryHelper,
UserID: userID,
Portal: portalKey,
}
}
const (
getBackfillStateQuery = `
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
FROM backfill_state
WHERE user_mxid=$1
AND portal_jid=$2
AND portal_receiver=$3
`
upsertBackfillStateQuery = `
INSERT INTO backfill_state
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
DO UPDATE SET
processing_batch=EXCLUDED.processing_batch,
backfill_complete=EXCLUDED.backfill_complete,
first_expected_ts=EXCLUDED.first_expected_ts
`
)
func (bq *BackfillStateQuery) GetBackfillState(ctx context.Context, userID id.UserID, portalKey PortalKey) (*BackfillState, error) {
return bq.QueryOne(ctx, getBackfillStateQuery, userID, portalKey.JID, portalKey.Receiver)
}
type BackfillState struct {
qh *dbutil.QueryHelper[*BackfillState]
UserID id.UserID
Portal PortalKey
ProcessingBatch bool
BackfillComplete bool
FirstExpectedTimestamp uint64
}
func (b *BackfillState) Scan(row dbutil.Scannable) (*BackfillState, error) {
return dbutil.ValueOrErr(b, row.Scan(
&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp,
))
}
func (b *BackfillState) sqlVariables() []any {
return []any{b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp}
}
func (b *BackfillState) Upsert(ctx context.Context) error {
return b.qh.Exec(ctx, upsertBackfillStateQuery, b.sqlVariables()...)
}
func (b *BackfillState) SetProcessingBatch(ctx context.Context, processing bool) error {
b.ProcessingBatch = processing
return b.Upsert(ctx)
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -26,6 +26,7 @@ import (
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-whatsapp/database/upgrades"
)
@ -44,28 +45,51 @@ type Database struct {
Reaction *ReactionQuery
DisappearingMessage *DisappearingMessageQuery
BackfillQueue *BackfillTaskQuery
BackfillState *BackfillStateQuery
Backfill *BackfillQuery
HistorySync *HistorySyncQuery
MediaBackfillRequest *MediaBackfillRequestQuery
}
func New(db *dbutil.Database) *Database {
func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
db := &Database{Database: baseDB}
db.UpgradeTable = upgrades.Table
return &Database{
Database: db,
User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
DisappearingMessage: &DisappearingMessageQuery{dbutil.MakeQueryHelper(db, newDisappearingMessage)},
BackfillQueue: &BackfillTaskQuery{dbutil.MakeQueryHelper(db, newBackfillTask)},
BackfillState: &BackfillStateQuery{dbutil.MakeQueryHelper(db, newBackfillState)},
HistorySync: &HistorySyncQuery{dbutil.MakeQueryHelper(db, newHistorySyncConversation)},
MediaBackfillRequest: &MediaBackfillRequestQuery{dbutil.MakeQueryHelper(db, newMediaBackfillRequest)},
db.User = &UserQuery{
db: db,
log: log.Sub("User"),
}
db.Portal = &PortalQuery{
db: db,
log: log.Sub("Portal"),
}
db.Puppet = &PuppetQuery{
db: db,
log: log.Sub("Puppet"),
}
db.Message = &MessageQuery{
db: db,
log: log.Sub("Message"),
}
db.Reaction = &ReactionQuery{
db: db,
log: log.Sub("Reaction"),
}
db.DisappearingMessage = &DisappearingMessageQuery{
db: db,
log: log.Sub("DisappearingMessage"),
}
db.Backfill = &BackfillQuery{
db: db,
log: log.Sub("Backfill"),
}
db.HistorySync = &HistorySyncQuery{
db: db,
log: log.Sub("HistorySync"),
}
db.MediaBackfillRequest = &MediaBackfillRequestQuery{
db: db,
log: log.Sub("MediaBackfillRequest"),
}
return db
}
func isRetryableError(err error) bool {

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,29 +17,32 @@
package database
import (
"context"
"database/sql"
"errors"
"time"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type DisappearingMessageQuery struct {
*dbutil.QueryHelper[*DisappearingMessage]
db *Database
log log.Logger
}
func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage {
func (dmq *DisappearingMessageQuery) New() *DisappearingMessage {
return &DisappearingMessage{
qh: qh,
db: dmq.db,
log: dmq.log,
}
}
func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage {
dm := &DisappearingMessage{
qh: dmq.QueryHelper,
db: dmq.db,
log: dmq.log,
RoomID: roomID,
EventID: eventID,
ExpireIn: expireIn,
@ -52,17 +55,22 @@ const (
getAllScheduledDisappearingMessagesQuery = `
SELECT room_id, event_id, expire_in, expire_at FROM disappearing_message WHERE expire_at IS NOT NULL AND expire_at <= $1
`
insertDisappearingMessageQuery = `INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`
updateDisappearingMessageExpiryQuery = "UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3"
deleteDisappearingMessageQuery = "DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2"
)
func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli())
func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(duration time.Duration) (messages []*DisappearingMessage) {
rows, err := dmq.db.Query(getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli())
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, dmq.New().Scan(rows))
}
return
}
type DisappearingMessage struct {
qh *dbutil.QueryHelper[*DisappearingMessage]
db *Database
log log.Logger
RoomID id.RoomID
EventID id.EventID
@ -70,33 +78,50 @@ type DisappearingMessage struct {
ExpireAt time.Time
}
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
var expireIn int64
var expireAt sql.NullInt64
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)
if err != nil {
return nil, err
if !errors.Is(err, sql.ErrNoRows) {
msg.log.Errorln("Database scan failed:", err)
}
return nil
}
msg.ExpireIn = time.Duration(expireIn) * time.Millisecond
if expireAt.Valid {
msg.ExpireAt = time.UnixMilli(expireAt.Int64)
}
return msg, nil
return msg
}
func (msg *DisappearingMessage) sqlVariables() []any {
return []any{msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), dbutil.UnixMilliPtr(msg.ExpireAt)}
func (msg *DisappearingMessage) Insert(txn dbutil.Execable) {
if txn == nil {
txn = msg.db
}
var expireAt sql.NullInt64
if !msg.ExpireAt.IsZero() {
expireAt.Valid = true
expireAt.Int64 = msg.ExpireAt.UnixMilli()
}
_, err := txn.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`,
msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), expireAt)
if err != nil {
msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err)
}
}
func (msg *DisappearingMessage) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...)
}
func (msg *DisappearingMessage) StartTimer(ctx context.Context) error {
func (msg *DisappearingMessage) StartTimer() {
msg.ExpireAt = time.Now().Add(msg.ExpireIn * time.Second)
return msg.qh.Exec(ctx, updateDisappearingMessageExpiryQuery, msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
_, err := msg.db.Exec("UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3", msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to update %s/%s: %v", msg.RoomID, msg.EventID, err)
}
}
func (msg *DisappearingMessage) Delete(ctx context.Context) error {
return msg.qh.Exec(ctx, deleteDisappearingMessageQuery, msg.RoomID, msg.EventID)
func (msg *DisappearingMessage) Delete() {
_, err := msg.db.Exec("DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2", msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to delete %s/%s: %v", msg.RoomID, msg.EventID, err)
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
// Copyright (C) 2022 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,7 +17,8 @@
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
@ -25,19 +26,23 @@ import (
"go.mau.fi/util/dbutil"
waProto "go.mau.fi/whatsmeow/binary/proto"
"google.golang.org/protobuf/proto"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type HistorySyncQuery struct {
*dbutil.QueryHelper[*HistorySyncConversation]
db *Database
log log.Logger
}
type HistorySyncConversation struct {
qh *dbutil.QueryHelper[*HistorySyncConversation]
db *Database
log log.Logger
UserID id.UserID
ConversationID string
PortalKey PortalKey
PortalKey *PortalKey
LastMessageTimestamp time.Time
MuteEndTime time.Time
Archived bool
@ -49,16 +54,18 @@ type HistorySyncConversation struct {
UnreadCount uint32
}
func newHistorySyncConversation(qh *dbutil.QueryHelper[*HistorySyncConversation]) *HistorySyncConversation {
func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation {
return &HistorySyncConversation{
qh: qh,
db: hsq.db,
log: hsq.log,
PortalKey: &PortalKey{},
}
}
func (hsq *HistorySyncQuery) NewConversationWithValues(
userID id.UserID,
conversationID string,
portalKey PortalKey,
portalKey *PortalKey,
lastMessageTimestamp,
muteEndTime uint64,
archived bool,
@ -67,10 +74,10 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
endOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType,
ephemeralExpiration *uint32,
markedAsUnread bool,
unreadCount uint32,
) *HistorySyncConversation {
unreadCount uint32) *HistorySyncConversation {
return &HistorySyncConversation{
qh: hsq.QueryHelper,
db: hsq.db,
log: hsq.log,
UserID: userID,
ConversationID: conversationID,
PortalKey: portalKey,
@ -87,17 +94,6 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
}
const (
upsertHistorySyncConversationQuery = `
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (user_mxid, conversation_id)
DO UPDATE SET
last_message_timestamp=CASE
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
ELSE history_sync_conversation.last_message_timestamp
END,
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
`
getNMostRecentConversations = `
SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
FROM history_sync_conversation
@ -112,19 +108,24 @@ const (
AND portal_jid=$2
AND portal_receiver=$3
`
deleteAllConversationsQuery = "DELETE FROM history_sync_conversation WHERE user_mxid=$1"
deleteHistorySyncConversationQuery = `
DELETE FROM history_sync_conversation
WHERE user_mxid=$1 AND conversation_id=$2
`
)
func (hsc *HistorySyncConversation) sqlVariables() []any {
return []any{
func (hsc *HistorySyncConversation) Upsert() {
_, err := hsc.db.Exec(`
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (user_mxid, conversation_id)
DO UPDATE SET
last_message_timestamp=CASE
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
ELSE history_sync_conversation.last_message_timestamp
END,
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
`,
hsc.UserID,
hsc.ConversationID,
hsc.PortalKey.JID,
hsc.PortalKey.Receiver,
hsc.PortalKey.JID.String(),
hsc.PortalKey.Receiver.String(),
hsc.LastMessageTimestamp,
hsc.Archived,
hsc.Pinned,
@ -133,16 +134,14 @@ func (hsc *HistorySyncConversation) sqlVariables() []any {
hsc.EndOfHistoryTransferType,
hsc.EphemeralExpiration,
hsc.MarkedAsUnread,
hsc.UnreadCount,
hsc.UnreadCount)
if err != nil {
hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err)
}
}
func (hsc *HistorySyncConversation) Upsert(ctx context.Context) error {
return hsc.qh.Exec(ctx, upsertHistorySyncConversationQuery, hsc.sqlVariables()...)
}
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) (*HistorySyncConversation, error) {
return dbutil.ValueOrErr(hsc, row.Scan(
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
err := row.Scan(
&hsc.UserID,
&hsc.ConversationID,
&hsc.PortalKey.JID,
@ -155,59 +154,69 @@ func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) (*HistorySyncConv
&hsc.EndOfHistoryTransferType,
&hsc.EphemeralExpiration,
&hsc.MarkedAsUnread,
&hsc.UnreadCount,
))
&hsc.UnreadCount)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
hsc.log.Errorln("Database scan failed:", err)
}
return nil
}
return hsc
}
func (hsq *HistorySyncQuery) GetRecentConversations(ctx context.Context, userID id.UserID, n int) ([]*HistorySyncConversation, error) {
func (hsq *HistorySyncQuery) GetRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
nPtr := &n
// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
if n < 0 && hsq.GetDB().Dialect == dbutil.Postgres {
if n < 0 && hsq.db.Dialect == dbutil.Postgres {
nPtr = nil
}
return hsq.QueryMany(ctx, getNMostRecentConversations, userID, nPtr)
rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
for rows.Next() {
conversations = append(conversations, hsq.NewConversation().Scan(rows))
}
return
}
func (hsq *HistorySyncQuery) GetConversation(ctx context.Context, userID id.UserID, portalKey PortalKey) (*HistorySyncConversation, error) {
return hsq.QueryOne(ctx, getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey PortalKey) (conversation *HistorySyncConversation) {
rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
if rows.Next() {
conversation = hsq.NewConversation().Scan(rows)
}
return
}
func (hsq *HistorySyncQuery) DeleteAllConversations(ctx context.Context, userID id.UserID) error {
return hsq.Exec(ctx, deleteAllConversationsQuery, userID)
func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) {
_, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical chat info for %s/%s: %v", userID, err)
}
}
const (
insertHistorySyncMessageQuery = `
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
`
getHistorySyncMessagesBetweenQueryTemplate = `
getMessagesBetween = `
SELECT data FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
%s
ORDER BY timestamp DESC
%s
`
deleteHistorySyncMessagesBetweenExclusiveQuery = `
deleteMessagesBetweenExclusive = `
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2 AND timestamp<$3 AND timestamp>$4
`
deleteAllHistorySyncMessagesQuery = "DELETE FROM history_sync_message WHERE user_mxid=$1"
deleteHistorySyncMessagesForPortalQuery = `
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
`
conversationHasHistorySyncMessagesQuery = `
SELECT EXISTS(
SELECT 1 FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
)
`
)
type HistorySyncMessage struct {
hsq *HistorySyncQuery
db *Database
log log.Logger
UserID id.UserID
ConversationID string
@ -222,8 +231,8 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
return nil, err
}
return &HistorySyncMessage{
hsq: hsq,
db: hsq.db,
log: hsq.log,
UserID: userID,
ConversationID: conversationID,
MessageID: messageID,
@ -232,27 +241,18 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
}, nil
}
func (hsm *HistorySyncMessage) Insert(ctx context.Context) error {
return hsm.hsq.Exec(ctx, insertHistorySyncMessageQuery, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
func (hsm *HistorySyncMessage) Insert() error {
_, err := hsm.db.Exec(`
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
return err
}
func scanWebMessageInfo(rows dbutil.Scannable) (*waProto.WebMessageInfo, error) {
var msgData []byte
err := rows.Scan(&msgData)
if err != nil {
return nil, err
}
var historySyncMsg waProto.HistorySyncMsg
err = proto.Unmarshal(msgData, &historySyncMsg)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal message: %w", err)
}
return historySyncMsg.GetMessage(), nil
}
func (hsq *HistorySyncQuery) GetMessagesBetween(ctx context.Context, userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) ([]*waProto.WebMessageInfo, error) {
func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) {
whereClauses := ""
args := []any{userID, conversationID}
args := []interface{}{userID, conversationID}
argNum := 3
if startTime != nil {
whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum)
@ -268,35 +268,80 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(ctx context.Context, userID id.U
if limit > 0 {
limitClause = fmt.Sprintf("LIMIT %d", limit)
}
query := fmt.Sprintf(getHistorySyncMessagesBetweenQueryTemplate, whereClauses, limitClause)
return dbutil.ConvertRowFn[*waProto.WebMessageInfo](scanWebMessageInfo).
NewRowIter(hsq.GetDB().Query(ctx, query, args...)).
AsList()
rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...)
defer rows.Close()
if err != nil || rows == nil {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
hsq.log.Warnfln("Failed to query messages between range: %v", err)
}
return nil
}
var msgData []byte
for rows.Next() {
err = rows.Scan(&msgData)
if err != nil {
hsq.log.Errorfln("Database scan failed: %v", err)
continue
}
var historySyncMsg waProto.HistorySyncMsg
err = proto.Unmarshal(msgData, &historySyncMsg)
if err != nil {
hsq.log.Errorfln("Failed to unmarshal history sync message: %v", err)
continue
}
messages = append(messages, historySyncMsg.Message)
}
return
}
func (hsq *HistorySyncQuery) DeleteMessages(ctx context.Context, userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
newest := messages[0]
beforeTS := time.Unix(int64(newest.GetMessageTimestamp())+1, 0)
oldest := messages[len(messages)-1]
afterTS := time.Unix(int64(oldest.GetMessageTimestamp())-1, 0)
return hsq.Exec(ctx, deleteHistorySyncMessagesBetweenExclusiveQuery, userID, conversationID, beforeTS, afterTS)
_, err := hsq.db.Exec(deleteMessagesBetweenExclusive, userID, conversationID, beforeTS, afterTS)
return err
}
func (hsq *HistorySyncQuery) DeleteAllMessages(ctx context.Context, userID id.UserID) error {
return hsq.Exec(ctx, deleteAllHistorySyncMessagesQuery, userID)
func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) {
_, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s: %v", userID, err)
}
}
func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
return hsq.Exec(ctx, deleteHistorySyncMessagesForPortalQuery, userID, portalKey.JID)
func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(userID id.UserID, portalKey PortalKey) {
_, err := hsq.db.Exec(`
DELETE FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
`, userID, portalKey.JID)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, portalKey.JID, err)
}
}
func (hsq *HistorySyncQuery) ConversationHasMessages(ctx context.Context, userID id.UserID, portalKey PortalKey) (exists bool, err error) {
err = hsq.GetDB().QueryRow(ctx, conversationHasHistorySyncMessagesQuery, userID, portalKey.JID).Scan(&exists)
func (hsq *HistorySyncQuery) ConversationHasMessages(userID id.UserID, portalKey PortalKey) (exists bool) {
err := hsq.db.QueryRow(`
SELECT EXISTS(
SELECT 1 FROM history_sync_message
WHERE user_mxid=$1 AND conversation_id=$2
)
`, userID, portalKey.JID).Scan(&exists)
if err != nil {
hsq.log.Warnfln("Failed to check if any messages are stored for %s/%s: %v", userID, portalKey.JID, err)
}
return
}
func (hsq *HistorySyncQuery) DeleteConversation(ctx context.Context, userID id.UserID, jid string) error {
func (hsq *HistorySyncQuery) DeleteConversation(userID id.UserID, jid string) {
// This will also clear history_sync_message as there's a foreign key constraint
return hsq.Exec(ctx, deleteHistorySyncConversationQuery, userID, jid)
_, err := hsq.db.Exec(`
DELETE FROM history_sync_conversation
WHERE user_mxid=$1 AND conversation_id=$2
`, userID, jid)
if err != nil {
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, jid, err)
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
// Copyright (C) 2022 Tulir Asokan, Sumner Evans
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,12 +17,14 @@
package database
import (
"context"
"database/sql"
"errors"
_ "github.com/mattn/go-sqlite3"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type MediaBackfillRequestStatus int
@ -34,46 +36,34 @@ const (
)
type MediaBackfillRequestQuery struct {
*dbutil.QueryHelper[*MediaBackfillRequest]
db *Database
log log.Logger
}
const (
getAllMediaBackfillRequestsForUserQuery = `
SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
FROM media_backfill_requests
WHERE user_mxid=$1
AND status=0
`
deleteAllMediaBackfillRequestsForUserQuery = "DELETE FROM media_backfill_requests WHERE user_mxid=$1"
upsertMediaBackfillRequestQuery = `
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id)
DO UPDATE SET
media_key=excluded.media_key,
status=excluded.status,
error=excluded.error
`
)
type MediaBackfillRequest struct {
db *Database
log log.Logger
func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(ctx context.Context, userID id.UserID) ([]*MediaBackfillRequest, error) {
return mbrq.QueryMany(ctx, getAllMediaBackfillRequestsForUserQuery, userID)
UserID id.UserID
PortalKey *PortalKey
EventID id.EventID
MediaKey []byte
Status MediaBackfillRequestStatus
Error string
}
func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(ctx context.Context, userID id.UserID) error {
return mbrq.Exec(ctx, deleteAllMediaBackfillRequestsForUserQuery, userID)
}
func newMediaBackfillRequest(qh *dbutil.QueryHelper[*MediaBackfillRequest]) *MediaBackfillRequest {
func (mbrq *MediaBackfillRequestQuery) newMediaBackfillRequest() *MediaBackfillRequest {
return &MediaBackfillRequest{
qh: qh,
db: mbrq.db,
log: mbrq.log,
PortalKey: &PortalKey{},
}
}
func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest {
func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest {
return &MediaBackfillRequest{
qh: mbrq.QueryHelper,
db: mbrq.db,
log: mbrq.log,
UserID: userID,
PortalKey: portalKey,
EventID: eventID,
@ -82,25 +72,62 @@ func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID
}
}
type MediaBackfillRequest struct {
qh *dbutil.QueryHelper[*MediaBackfillRequest]
const (
getMediaBackfillRequestsForUser = `
SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
FROM media_backfill_requests
WHERE user_mxid=$1
AND status=0
`
)
UserID id.UserID
PortalKey PortalKey
EventID id.EventID
MediaKey []byte
Status MediaBackfillRequestStatus
Error string
func (mbr *MediaBackfillRequest) Upsert() {
_, err := mbr.db.Exec(`
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id)
DO UPDATE SET
media_key=EXCLUDED.media_key,
status=EXCLUDED.status,
error=EXCLUDED.error`,
mbr.UserID,
mbr.PortalKey.JID.String(),
mbr.PortalKey.Receiver.String(),
mbr.EventID,
mbr.MediaKey,
mbr.Status,
mbr.Error)
if err != nil {
mbr.log.Warnfln("Failed to insert media backfill request %s/%s/%s: %v", mbr.UserID, mbr.PortalKey.String(), mbr.EventID, err)
}
}
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) (*MediaBackfillRequest, error) {
return dbutil.ValueOrErr(mbr, row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error))
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest {
err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
mbr.log.Errorln("Database scan failed:", err)
}
return nil
}
return mbr
}
func (mbr *MediaBackfillRequest) sqlVariables() []any {
return []any{mbr.UserID, mbr.PortalKey.JID, mbr.PortalKey.Receiver, mbr.EventID, mbr.MediaKey, mbr.Status, mbr.Error}
func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(userID id.UserID) (requests []*MediaBackfillRequest) {
rows, err := mbrq.db.Query(getMediaBackfillRequestsForUser, userID)
defer rows.Close()
if err != nil || rows == nil {
return nil
}
for rows.Next() {
requests = append(requests, mbrq.newMediaBackfillRequest().Scan(rows))
}
return
}
func (mbr *MediaBackfillRequest) Upsert(ctx context.Context) error {
return mbr.qh.Exec(ctx, upsertMediaBackfillRequestQuery, mbr.sqlVariables()...)
func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(userID id.UserID) {
_, err := mbrq.db.Exec("DELETE FROM media_backfill_requests WHERE user_mxid=$1", userID)
if err != nil {
mbrq.log.Warnfln("Failed to delete media backfill requests for %s: %v", userID, err)
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2021 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,22 +17,29 @@
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type MessageQuery struct {
*dbutil.QueryHelper[*Message]
db *Database
log log.Logger
}
func newMessage(qh *dbutil.QueryHelper[*Message]) *Message {
return &Message{qh: qh}
func (mq *MessageQuery) New() *Message {
return &Message{
db: mq.db,
log: mq.log,
}
}
const (
@ -60,47 +67,60 @@ const (
SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
`
insertMessageQuery = `
INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
markMessageSentQuery = "UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4"
updateMessageMXIDQuery = "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
deleteMessageQuery = "DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3"
)
func (mq *MessageQuery) GetAll(ctx context.Context, chat PortalKey) ([]*Message, error) {
return mq.QueryMany(ctx, getAllMessagesQuery, chat.JID, chat.Receiver)
}
func (mq *MessageQuery) GetByJID(ctx context.Context, chat PortalKey, jid types.MessageID) (*Message, error) {
return mq.QueryOne(ctx, getMessageByJIDQuery, chat.JID, chat.Receiver, jid)
}
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
}
func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat PortalKey) (*Message, error) {
return mq.GetLastInChatBefore(ctx, chat, time.Now().Add(60*time.Second))
}
func (mq *MessageQuery) GetLastInChatBefore(ctx context.Context, chat PortalKey, maxTimestamp time.Time) (*Message, error) {
msg, err := mq.QueryOne(ctx, getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix())
if msg != nil && msg.Timestamp.IsZero() {
// Old db, we don't know what the last message is.
msg = nil
func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
if err != nil || rows == nil {
return nil
}
return msg, err
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return
}
func (mq *MessageQuery) GetFirstInChat(ctx context.Context, chat PortalKey) (*Message, error) {
return mq.QueryOne(ctx, getFirstMessageInChatQuery, chat.JID, chat.Receiver)
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
}
func (mq *MessageQuery) GetMessagesBetween(ctx context.Context, chat PortalKey, minTimestamp, maxTimestamp time.Time) ([]*Message, error) {
return mq.QueryMany(ctx, getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
}
func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
}
func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
if msg == nil || msg.Timestamp.IsZero() {
// Old db, we don't know what the last message is.
return nil
}
return msg
}
func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
}
func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) {
rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
if err != nil || rows == nil {
return nil
}
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return
}
func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
if row == nil {
return nil
}
return mq.New().Scan(row)
}
type MessageErrorType string
@ -124,7 +144,8 @@ const (
)
type Message struct {
qh *dbutil.QueryHelper[*Message]
db *Database
log log.Logger
Chat PortalKey
JID types.MessageID
@ -151,49 +172,76 @@ func (msg *Message) IsFakeJID() bool {
const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s"
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
func (msg *Message) Scan(row dbutil.Scannable) *Message {
var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
if err != nil {
return nil, err
if !errors.Is(err, sql.ErrNoRows) {
msg.log.Errorln("Database scan failed:", err)
}
return nil
}
if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") {
_, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID)
if err != nil {
return nil, fmt.Errorf("failed to parse gallery MXID: %w", err)
msg.log.Errorln("Parsing gallery MXID failed:", err)
}
}
if ts != 0 {
msg.Timestamp = time.Unix(ts, 0)
}
return msg, nil
return msg
}
func (msg *Message) sqlVariables() []any {
func (msg *Message) Insert(txn dbutil.Execable) {
if txn == nil {
txn = msg.db
}
var sender interface{} = msg.Sender
// Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
if msg.Sender.IsEmpty() {
sender = ""
}
mxid := msg.MXID.String()
if msg.GalleryPart != 0 {
mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid)
}
return []any{msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, msg.Sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID}
_, err := txn.Exec(`
INSERT INTO message
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`, msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
}
}
func (msg *Message) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...)
}
func (msg *Message) MarkSent(ctx context.Context, ts time.Time) error {
func (msg *Message) MarkSent(ts time.Time) {
msg.Sent = true
msg.Timestamp = ts
return msg.qh.Exec(ctx, markMessageSentQuery, ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
_, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}
}
func (msg *Message) UpdateMXID(ctx context.Context, mxid id.EventID, newType MessageType, newError MessageErrorType) error {
func (msg *Message) UpdateMXID(txn dbutil.Execable, mxid id.EventID, newType MessageType, newError MessageErrorType) {
if txn == nil {
txn = msg.db
}
msg.MXID = mxid
msg.Type = newType
msg.Error = newError
return msg.qh.Exec(ctx, updateMessageMXIDQuery, mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
_, err := txn.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6",
mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
}
}
func (msg *Message) Delete(ctx context.Context) error {
return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
func (msg *Message) Delete() {
_, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil {
msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err)
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,56 +17,28 @@
package database
import (
"context"
"fmt"
"strings"
"github.com/lib/pq"
"go.mau.fi/util/dbutil"
)
const (
bulkPutPollOptionsQuery = "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
bulkPutPollOptionsQueryTemplate = "($1, $%d, $%d)"
bulkPutPollOptionsQueryPlaceholder = "($1, $2, $3)"
getPollOptionIDsByHashesQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
getPollOptionHashesByIDsQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
getPollOptionQuerySQLiteArrayTemplate = " IN (%s)"
getPollOptionQueryArrayPlaceholder = " = ANY($2)"
)
func init() {
if strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, "meow") == bulkPutPollOptionsQuery {
panic("Bulk insert query placeholder not found")
}
if strings.ReplaceAll(getPollOptionIDsByHashesQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
panic("Array select query placeholder not found")
}
if strings.ReplaceAll(getPollOptionHashesByIDsQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
panic("Array select query placeholder not found")
}
}
type pollOption struct {
id string
hash [32]byte
}
func scanPollOption(rows dbutil.Scannable) (*pollOption, error) {
func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
var hash []byte
var id string
err := rows.Scan(&id, &hash)
err = rows.Scan(&id, &hash)
if err != nil {
return nil, err
// return below
} else if len(hash) != 32 {
return nil, fmt.Errorf("unexpected hash length %d", len(hash))
err = fmt.Errorf("unexpected hash length %d", len(hash))
} else {
return &pollOption{id: id, hash: [32]byte(hash)}, nil
hashArr = *(*[32]byte)(hash)
}
return
}
func (msg *Message) PutPollOptions(ctx context.Context, opts map[[32]byte]string) error {
func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
args := make([]any, len(opts)*2+1)
placeholders := make([]string, len(opts))
args[0] = msg.MXID
@ -75,47 +47,72 @@ func (msg *Message) PutPollOptions(ctx context.Context, opts map[[32]byte]string
args[i*2+1] = id
hashCopy := hash
args[i*2+2] = hashCopy[:]
placeholders[i] = fmt.Sprintf(bulkPutPollOptionsQueryTemplate, i*2+2, i*2+3)
placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
i++
}
query := strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, strings.Join(placeholders, ","))
return msg.qh.Exec(ctx, query, args...)
query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
_, err := msg.db.Exec(query, args...)
if err != nil {
msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
}
}
func getPollOptions[LookupKey any, Key comparable, Value any](
ctx context.Context,
msg *Message,
query string,
things []LookupKey,
getKeyValue func(option *pollOption) (Key, Value),
) (map[Key]Value, error) {
func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
var args []any
if msg.qh.GetDB().Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(things)}
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(hashes)}
} else {
query = strings.ReplaceAll(query, getPollOptionQueryArrayPlaceholder, fmt.Sprintf(getPollOptionQuerySQLiteArrayTemplate, strings.TrimSuffix(strings.Repeat("?,", len(things)), ",")))
args = make([]any, len(things)+1)
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
args = make([]any, len(hashes)+1)
args[0] = msg.MXID
for i, thing := range things {
args[i+1] = thing
for i, hash := range hashes {
args[i+1] = hash
}
}
return dbutil.RowIterAsMap(
dbutil.ConvertRowFn[*pollOption](scanPollOption).NewRowIter(msg.qh.GetDB().Query(ctx, query, args...)),
getKeyValue,
)
ids := make(map[[32]byte]string, len(hashes))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
break
}
ids[hash] = id
}
}
return ids
}
func (msg *Message) GetPollOptionIDs(ctx context.Context, hashes [][]byte) (map[[32]byte]string, error) {
return getPollOptions(
ctx, msg, getPollOptionIDsByHashesQuery, hashes,
func(t *pollOption) ([32]byte, string) { return t.hash, t.id },
)
}
func (msg *Message) GetPollOptionHashes(ctx context.Context, ids []string) (map[string][32]byte, error) {
return getPollOptions(
ctx, msg, getPollOptionHashesByIDsQuery, ids,
func(t *pollOption) (string, [32]byte) { return t.id, t.hash },
)
func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
var args []any
if msg.db.Dialect == dbutil.Postgres {
args = []any{msg.MXID, pq.Array(ids)}
} else {
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
args = make([]any, len(ids)+1)
args[0] = msg.MXID
for i, id := range ids {
args[i+1] = id
}
}
hashes := make(map[string][32]byte, len(ids))
rows, err := msg.db.Query(query, args...)
if err != nil {
msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
} else {
for rows.Next() {
id, hash, err := scanPollOptionMapping(rows)
if err != nil {
msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
break
}
hashes[id] = hash
}
}
return hashes
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2021 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,14 +17,15 @@
package database
import (
"context"
"database/sql"
"fmt"
"time"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type PortalKey struct {
@ -52,89 +53,90 @@ func (key PortalKey) String() string {
}
type PortalQuery struct {
*dbutil.QueryHelper[*Portal]
db *Database
log log.Logger
}
func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
func (pq *PortalQuery) New() *Portal {
return &Portal{
qh: qh,
db: pq.db,
log: pq.log,
}
}
const (
getAllPortalsQuery = `
SELECT jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space,
first_event_id, next_batch_id, relay_user_id, expiration_time
FROM portal
`
getPortalByJIDQuery = getAllPortalsQuery + " WHERE jid=$1 AND receiver=$2"
getPortalByMXIDQuery = getAllPortalsQuery + " WHERE mxid=$1"
getPrivateChatsWithQuery = getAllPortalsQuery + " WHERE jid=$1"
getPrivateChatsOfQuery = getAllPortalsQuery + " WHERE receiver=$1"
getAllPortalsByParentGroupQuery = getAllPortalsQuery + " WHERE parent_group=$1"
findPrivateChatPortalsNotInSpaceQuery = `
const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, relay_user_id, expiration_time"
func (pq *PortalQuery) GetAll() []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns))
}
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid)
}
func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD())
}
func (pq *PortalQuery) GetAllByParentGroup(jid types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE parent_group=$1", portalColumns), jid)
}
func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns), receiver.ToNonAD())
}
func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
receiver = receiver.ToNonAD()
rows, err := pq.db.Query(`
SELECT jid FROM portal
LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
WHERE mxid<>'' AND receiver=$1 AND (user_portal.in_space=false OR user_portal.in_space IS NULL)
`
insertPortalQuery = `
INSERT INTO portal (
jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space,
first_event_id, next_batch_id, relay_user_id, expiration_time
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
`
updatePortalQuery = `
UPDATE portal
SET mxid=$3, name=$4, name_set=$5, topic=$6, topic_set=$7, avatar=$8, avatar_url=$9, avatar_set=$10,
encrypted=$11, last_sync=$12, is_parent=$13, parent_group=$14, in_space=$15,
first_event_id=$16, next_batch_id=$17, relay_user_id=$18, expiration_time=$19
WHERE jid=$1 AND receiver=$2
`
clearPortalInSpaceQuery = "UPDATE portal SET in_space=false WHERE parent_group=$1"
deletePortalQuery = "DELETE FROM portal WHERE jid=$1 AND receiver=$2"
)
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsQuery)
}
func (pq *PortalQuery) GetByJID(ctx context.Context, key PortalKey) (*Portal, error) {
return pq.QueryOne(ctx, getPortalByJIDQuery, key.JID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid)
}
func (pq *PortalQuery) GetAllByJID(ctx context.Context, jid types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getPrivateChatsWithQuery, jid.ToNonAD())
}
func (pq *PortalQuery) FindPrivateChats(ctx context.Context, receiver types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getPrivateChatsOfQuery, receiver.ToNonAD())
}
func (pq *PortalQuery) GetAllByParentGroup(ctx context.Context, jid types.JID) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsByParentGroupQuery, jid)
}
func (pq *PortalQuery) FindPrivateChatsNotInSpace(ctx context.Context, receiver types.JID) (keys []PortalKey, err error) {
receiver = receiver.ToNonAD()
scanFn := func(rows dbutil.Scannable) (key PortalKey, err error) {
key.Receiver = receiver
err = rows.Scan(&key.JID)
`, receiver)
if err != nil {
pq.log.Errorfln("Failed to find private chats not in space for %s: %v", receiver, err)
return
} else if rows == nil {
return
}
return dbutil.ConvertRowFn[PortalKey](scanFn).
NewRowIter(pq.GetDB().Query(ctx, findPrivateChatPortalsNotInSpaceQuery, receiver)).
AsList()
for rows.Next() {
var key PortalKey
key.Receiver = receiver
err = rows.Scan(&key.JID)
if err == nil {
keys = append(keys, key)
}
}
return
}
func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
portals = append(portals, pq.New().Scan(rows))
}
return
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
}
type Portal struct {
qh *dbutil.QueryHelper[*Portal]
db *Database
log log.Logger
Key PortalKey
MXID id.RoomID
@ -159,17 +161,15 @@ type Portal struct {
ExpirationTime uint32
}
func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString
var lastSyncTs int64
err := row.Scan(
&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet,
&portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted,
&lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace,
&firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime,
)
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
if err != nil {
return nil, err
if err != sql.ErrNoRows {
portal.log.Errorln("Database scan failed:", err)
}
return nil
}
if lastSyncTs > 0 {
portal.LastSync = time.Unix(lastSyncTs, 0)
@ -182,36 +182,96 @@ func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
portal.FirstEventID = id.EventID(firstEventID.String)
portal.NextBatchID = id.BatchID(nextBatchID.String)
portal.RelayUserID = id.UserID(relayUserID.String)
return portal, nil
return portal
}
func (portal *Portal) sqlVariables() []any {
var lastSyncTS int64
if !portal.LastSync.IsZero() {
lastSyncTS = portal.LastSync.Unix()
func (portal *Portal) mxidPtr() *id.RoomID {
if len(portal.MXID) > 0 {
return &portal.MXID
}
return []any{
portal.Key.JID, portal.Key.Receiver, dbutil.StrPtr(portal.MXID), portal.Name, portal.NameSet,
portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted,
lastSyncTS, portal.IsParent, dbutil.StrPtr(portal.ParentGroup.String()), portal.InSpace,
portal.FirstEventID.String(), portal.NextBatchID.String(), dbutil.StrPtr(portal.RelayUserID), portal.ExpirationTime,
return nil
}
func (portal *Portal) relayUserPtr() *id.UserID {
if len(portal.RelayUserID) > 0 {
return &portal.RelayUserID
}
return nil
}
func (portal *Portal) parentGroupPtr() *string {
if !portal.ParentGroup.IsEmpty() {
val := portal.ParentGroup.String()
return &val
}
return nil
}
func (portal *Portal) lastSyncTs() int64 {
if portal.LastSync.IsZero() {
return 0
}
return portal.LastSync.Unix()
}
func (portal *Portal) Insert() {
_, err := portal.db.Exec(`
INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id,
relay_user_id, expiration_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
`,
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet,
portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(),
portal.IsParent, portal.parentGroupPtr(), portal.InSpace, portal.FirstEventID.String(), portal.NextBatchID.String(),
portal.relayUserPtr(), portal.ExpirationTime)
if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
}
}
func (portal *Portal) Insert(ctx context.Context) error {
return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...)
func (portal *Portal) Update(txn dbutil.Execable) {
if txn == nil {
txn = portal.db
}
_, err := txn.Exec(`
UPDATE portal
SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13,
first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17
WHERE jid=$18 AND receiver=$19
`, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(),
portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), portal.IsParent, portal.parentGroupPtr(), portal.InSpace,
portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime,
portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
}
}
func (portal *Portal) Update(ctx context.Context) error {
return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...)
}
func (portal *Portal) Delete(ctx context.Context) error {
return portal.qh.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
err := portal.qh.Exec(ctx, clearPortalInSpaceQuery, portal.Key.JID)
func (portal *Portal) Delete() {
txn, err := portal.db.Begin()
if err != nil {
portal.log.Errorfln("Failed to begin transaction to delete portal %v: %v", portal.Key, err)
return
}
defer func() {
if err != nil {
return err
err = txn.Rollback()
if err != nil {
portal.log.Warnfln("Failed to rollback failed portal delete transaction: %v", err)
}
} else if err = txn.Commit(); err != nil {
portal.log.Warnfln("Failed to commit portal delete transaction: %v", err)
}
return portal.qh.Exec(ctx, deletePortalQuery, portal.Key.JID, portal.Key.Receiver)
})
}()
_, err = txn.Exec("UPDATE portal SET in_space=false WHERE parent_group=$1", portal.Key.JID)
if err != nil {
portal.log.Warnfln("Failed to mark child groups of %v as not in space: %v", portal.Key.JID, err)
return
}
_, err = txn.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
if err != nil {
portal.log.Warnfln("Failed to delete %v: %v", portal.Key, err)
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2021 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,70 +17,74 @@
package database
import (
"context"
"database/sql"
"time"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type PuppetQuery struct {
*dbutil.QueryHelper[*Puppet]
db *Database
log log.Logger
}
func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
func (pq *PuppetQuery) New() *Puppet {
return &Puppet{
qh: qh,
db: pq.db,
log: pq.log,
EnablePresence: true,
EnableReceipts: true,
}
}
const (
getAllPuppetsQuery = `
SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts
FROM puppet
`
getPuppetByJIDQuery = getAllPuppetsQuery + " WHERE username=$1"
getPuppetByCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid=$1"
getAllPuppetsWithCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid<>''"
insertPuppetQuery = `
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`
updatePuppetQuery = `
UPDATE puppet
SET avatar=$2, avatar_url=$3, avatar_set=$4, displayname=$5, name_quality=$6, name_set=$7, contact_info_set=$8,
last_sync=$9, custom_mxid=$10, access_token=$11, next_batch=$12, enable_presence=$13, enable_receipts=$14
WHERE username=$1
`
)
func (pq *PuppetQuery) GetAll(ctx context.Context) ([]*Puppet, error) {
return pq.QueryMany(ctx, getAllPuppetsQuery)
func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet")
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
return
}
func (pq *PuppetQuery) Get(ctx context.Context, jid types.JID) (*Puppet, error) {
return pq.QueryOne(ctx, getPuppetByJIDQuery, jid.User)
func (pq *PuppetQuery) Get(jid types.JID) *Puppet {
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE username=$1", jid.User)
if row == nil {
return nil
}
return pq.New().Scan(row)
}
func (pq *PuppetQuery) GetByCustomMXID(ctx context.Context, mxid id.UserID) (*Puppet, error) {
return pq.QueryOne(ctx, getPuppetByCustomMXIDQuery, mxid)
func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid=$1", mxid)
if row == nil {
return nil
}
return pq.New().Scan(row)
}
func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, error) {
return pq.QueryMany(ctx, getAllPuppetsWithCustomMXIDQuery)
func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid<>''")
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
return
}
type Puppet struct {
qh *dbutil.QueryHelper[*Puppet]
db *Database
log log.Logger
JID types.JID
Avatar string
@ -99,14 +103,17 @@ type Puppet struct {
EnableReceipts bool
}
func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
var quality, lastSync sql.NullInt64
var enablePresence, enableReceipts, nameSet, avatarSet, contactInfoSet sql.NullBool
var username string
err := row.Scan(&username, &avatar, &avatarURL, &displayname, &quality, &nameSet, &avatarSet, &contactInfoSet, &lastSync, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts)
if err != nil {
return nil, err
if err != sql.ErrNoRows {
puppet.log.Errorln("Database scan failed:", err)
}
return nil
}
puppet.JID = types.NewJID(username, types.DefaultUserServer)
puppet.Displayname = displayname.String
@ -124,30 +131,45 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
puppet.NextBatch = nextBatch.String
puppet.EnablePresence = enablePresence.Bool
puppet.EnableReceipts = enableReceipts.Bool
return puppet, nil
return puppet
}
func (puppet *Puppet) sqlVariables() []any {
var lastSyncTS int64
if !puppet.LastSync.IsZero() {
lastSyncTS = puppet.LastSync.Unix()
}
return []any{
puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTS,
puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
puppet.EnablePresence, puppet.EnableReceipts,
}
}
func (puppet *Puppet) Insert(ctx context.Context) error {
func (puppet *Puppet) Insert() {
if puppet.JID.Server != types.DefaultUserServer {
zerolog.Ctx(ctx).Warn().Stringer("jid", puppet.JID).Msg("Not inserting puppet: not a user")
return nil
puppet.log.Warnfln("Not inserting %s: not a user", puppet.JID)
return
}
var lastSyncTs int64
if !puppet.LastSync.IsZero() {
lastSyncTs = puppet.LastSync.Unix()
}
_, err := puppet.db.Exec(`
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set,
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`, puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
puppet.EnablePresence, puppet.EnableReceipts,
)
if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
}
return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...)
}
func (puppet *Puppet) Update(ctx context.Context) error {
return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...)
func (puppet *Puppet) Update() {
var lastSyncTs int64
if !puppet.LastSync.IsZero() {
lastSyncTs = puppet.LastSync.Unix()
}
_, err := puppet.db.Exec(`
UPDATE puppet
SET displayname=$1, name_quality=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, contact_info_set=$7, last_sync=$8,
custom_mxid=$9, access_token=$10, next_batch=$11, enable_presence=$12, enable_receipts=$13
WHERE username=$14
`, puppet.Displayname, puppet.NameQuality, puppet.NameSet, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.ContactInfoSet,
lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts,
puppet.JID.User)
if err != nil {
puppet.log.Warnfln("Failed to update %s: %v", puppet.JID, err)
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,20 +17,26 @@
package database
import (
"context"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/id"
"database/sql"
"errors"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type ReactionQuery struct {
*dbutil.QueryHelper[*Reaction]
db *Database
log log.Logger
}
func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
return &Reaction{qh: qh}
func (rq *ReactionQuery) New() *Reaction {
return &Reaction{
db: rq.db,
log: rq.log,
}
}
const (
@ -49,20 +55,28 @@ const (
DO UPDATE SET mxid=excluded.mxid, jid=excluded.jid
`
deleteReactionQuery = `
DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4
DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4 AND mxid=$5
`
)
func (rq *ReactionQuery) GetByTargetJID(ctx context.Context, chat PortalKey, jid types.MessageID, sender types.JID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD())
func (rq *ReactionQuery) GetByTargetJID(chat PortalKey, jid types.MessageID, sender types.JID) *Reaction {
return rq.maybeScan(rq.db.QueryRow(getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD()))
}
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction {
return rq.maybeScan(rq.db.QueryRow(getReactionByMXIDQuery, mxid))
}
func (rq *ReactionQuery) maybeScan(row *sql.Row) *Reaction {
if row == nil {
return nil
}
return rq.New().Scan(row)
}
type Reaction struct {
qh *dbutil.QueryHelper[*Reaction]
db *Database
log log.Logger
Chat PortalKey
TargetJID types.MessageID
@ -71,19 +85,35 @@ type Reaction struct {
JID types.MessageID
}
func (reaction *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
return dbutil.ValueOrErr(reaction, row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID))
func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction {
err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
reaction.log.Errorln("Database scan failed:", err)
}
return nil
}
return reaction
}
func (reaction *Reaction) sqlVariables() []any {
func (reaction *Reaction) Upsert(txn dbutil.Execable) {
reaction.Sender = reaction.Sender.ToNonAD()
return []any{reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID}
if txn == nil {
txn = reaction.db
}
_, err := txn.Exec(upsertReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID)
if err != nil {
reaction.log.Warnfln("Failed to upsert reaction to %s@%s by %s: %v", reaction.Chat, reaction.TargetJID, reaction.Sender, err)
}
}
func (reaction *Reaction) Upsert(ctx context.Context) error {
return reaction.qh.Exec(ctx, upsertReactionQuery, reaction.sqlVariables()...)
func (reaction *Reaction) GetTarget() *Message {
return reaction.db.Message.GetByJID(reaction.Chat, reaction.TargetJID)
}
func (reaction *Reaction) Delete(ctx context.Context) error {
return reaction.qh.Exec(ctx, deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender)
func (reaction *Reaction) Delete() {
_, err := reaction.db.Exec(deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID)
if err != nil {
reaction.log.Warnfln("Failed to delete reaction %s: %v", reaction.MXID, err)
}
}

View file

@ -17,7 +17,6 @@
package upgrades
import (
"context"
"embed"
"errors"
@ -30,7 +29,7 @@ var Table dbutil.UpgradeTable
var rawUpgrades embed.FS
func init() {
Table.Register(-1, 35, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
Table.Register(-1, 35, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
})
Table.RegisterFS(rawUpgrades)

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,65 +17,63 @@
package database
import (
"context"
"database/sql"
"sync"
"time"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/id"
"go.mau.fi/util/dbutil"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type UserQuery struct {
*dbutil.QueryHelper[*User]
db *Database
log log.Logger
}
func newUser(qh *dbutil.QueryHelper[*User]) *User {
func (uq *UserQuery) New() *User {
return &User{
qh: qh,
db: uq.db,
log: uq.log,
lastReadCache: make(map[PortalKey]time.Time),
inSpaceCache: make(map[PortalKey]bool),
}
}
const (
getAllUsersQuery = `SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`
getUserByMXIDQuery = getAllUsersQuery + ` WHERE mxid=$1`
getUserByUsernameQuery = getAllUsersQuery + ` WHERE username=$1`
insertUserQuery = `
INSERT INTO "user" (
mxid, username, agent, device,
management_room, space_room,
phone_last_seen, phone_last_pinged, timezone
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
updateUserQuery = `
UPDATE "user"
SET username=$1, agent=$2, device=$3,
management_room=$4, space_room=$5,
phone_last_seen=$6, phone_last_pinged=$7, timezone=$8
WHERE mxid=$9
`
getUserLastAppStateKeyIDQuery = "SELECT key_id FROM whatsmeow_app_state_sync_keys WHERE jid=$1 ORDER BY timestamp DESC LIMIT 1"
)
func (uq *UserQuery) GetAll(ctx context.Context) ([]*User, error) {
return uq.QueryMany(ctx, getAllUsersQuery)
func (uq *UserQuery) GetAll() (users []*User) {
rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
for rows.Next() {
users = append(users, uq.New().Scan(rows))
}
return
}
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
return uq.QueryOne(ctx, getUserByMXIDQuery, userID)
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE mxid=$1`, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
}
func (uq *UserQuery) GetByUsername(ctx context.Context, username string) (*User, error) {
return uq.QueryOne(ctx, getUserByUsernameQuery, username)
func (uq *UserQuery) GetByUsername(username string) *User {
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE username=$1`, username)
if row == nil {
return nil
}
return uq.New().Scan(row)
}
type User struct {
qh *dbutil.QueryHelper[*User]
db *Database
log log.Logger
MXID id.UserID
JID types.JID
@ -91,21 +89,20 @@ type User struct {
inSpaceCacheLock sync.Mutex
}
func (user *User) Scan(row dbutil.Scannable) (*User, error) {
func (user *User) Scan(row dbutil.Scannable) *User {
var username, timezone sql.NullString
var device, agent sql.NullInt16
var device, agent sql.NullByte
var phoneLastSeen, phoneLastPinged sql.NullInt64
err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged, &timezone)
if err != nil {
return nil, err
if err != sql.ErrNoRows {
user.log.Errorln("Database scan failed:", err)
}
return nil
}
user.Timezone = timezone.String
if len(username.String) > 0 {
user.JID = types.JID{
User: username.String,
Device: uint16(device.Int16),
Server: types.DefaultUserServer,
}
user.JID = types.NewADJID(username.String, agent.Byte, device.Byte)
}
if phoneLastSeen.Valid {
user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0)
@ -113,34 +110,66 @@ func (user *User) Scan(row dbutil.Scannable) (*User, error) {
if phoneLastPinged.Valid {
user.PhoneLastPinged = time.Unix(phoneLastPinged.Int64, 0)
}
return user, nil
return user
}
func (user *User) sqlVariables() []any {
var username *string
var agent, device *uint16
func (user *User) usernamePtr() *string {
if !user.JID.IsEmpty() {
username = dbutil.StrPtr(user.JID.User)
var zero uint16
agent = &zero
device = dbutil.NumPtr(user.JID.Device)
return &user.JID.User
}
return []any{
username, agent, device, user.ManagementRoom, user.SpaceRoom,
dbutil.UnixPtr(user.PhoneLastSeen), dbutil.UnixPtr(user.PhoneLastPinged),
user.Timezone, user.MXID,
return nil
}
func (user *User) agentPtr() *uint8 {
if !user.JID.IsEmpty() {
zero := uint8(0)
return &zero
}
return nil
}
func (user *User) devicePtr() *uint8 {
if !user.JID.IsEmpty() {
device := uint8(user.JID.Device)
return &device
}
return nil
}
func (user *User) phoneLastSeenPtr() *int64 {
if user.PhoneLastSeen.IsZero() {
return nil
}
ts := user.PhoneLastSeen.Unix()
return &ts
}
func (user *User) phoneLastPingedPtr() *int64 {
if user.PhoneLastPinged.IsZero() {
return nil
}
ts := user.PhoneLastPinged.Unix()
return &ts
}
func (user *User) Insert() {
_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone)
if err != nil {
user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
}
}
func (user *User) Insert(ctx context.Context) error {
return user.qh.Exec(ctx, insertUserQuery, user.sqlVariables()...)
func (user *User) Update() {
_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7, timezone=$8 WHERE mxid=$9`,
user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone, user.MXID)
if err != nil {
user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
}
}
func (user *User) Update(ctx context.Context) error {
return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...)
}
func (user *User) GetLastAppStateKeyID(ctx context.Context) (keyID []byte, err error) {
err = user.qh.GetDB().QueryRow(ctx, getUserLastAppStateKeyIDQuery, user.JID).Scan(&keyID)
return
func (user *User) GetLastAppStateKeyID() ([]byte, error) {
var keyID []byte
err := user.db.QueryRow("SELECT key_id FROM whatsmeow_app_state_sync_keys ORDER BY timestamp DESC LIMIT 1").Scan(&keyID)
return keyID, err
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2021 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,97 +17,69 @@
package database
import (
"context"
"database/sql"
"errors"
"time"
"github.com/rs/zerolog"
)
const (
getLastReadTSQuery = "SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
setLastReadTSQuery = `
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
`
getIsInSpaceQuery = "SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
setIsInSpaceQuery = `
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
`
)
func (user *User) GetLastReadTS(ctx context.Context, portal PortalKey) time.Time {
func (user *User) GetLastReadTS(portal PortalKey) time.Time {
user.lastReadCacheLock.Lock()
defer user.lastReadCacheLock.Unlock()
if cached, ok := user.lastReadCache[portal]; ok {
return cached
}
var ts int64
var parsedTS time.Time
err := user.qh.GetDB().QueryRow(ctx, getLastReadTSQuery, user.MXID, portal.JID, portal.Receiver).Scan(&ts)
err := user.db.QueryRow("SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&ts)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to query last read timestamp")
return parsedTS
user.log.Warnfln("Failed to scan last read timestamp from user portal table: %v", err)
}
if ts != 0 {
parsedTS = time.Unix(ts, 0)
if ts == 0 {
user.lastReadCache[portal] = time.Time{}
} else {
user.lastReadCache[portal] = time.Unix(ts, 0)
}
user.lastReadCache[portal] = parsedTS
return user.lastReadCache[portal]
}
func (user *User) SetLastReadTS(ctx context.Context, portal PortalKey, ts time.Time) {
func (user *User) SetLastReadTS(portal PortalKey, ts time.Time) {
user.lastReadCacheLock.Lock()
defer user.lastReadCacheLock.Unlock()
_, err := user.qh.GetDB().Exec(ctx, setLastReadTSQuery, user.MXID, portal.JID, portal.Receiver, ts.Unix())
_, err := user.db.Exec(`
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
`, user.MXID, portal.JID, portal.Receiver, ts.Unix())
if err != nil {
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to update last read timestamp")
user.log.Warnfln("Failed to update last read timestamp: %v", err)
} else {
zerolog.Ctx(ctx).Debug().
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Time("last_read_ts", ts).
Msg("Updated last read timestamp of portal")
user.log.Debugfln("Set last read timestamp of %s in %s to %d", user.MXID, portal.String(), ts.Unix())
user.lastReadCache[portal] = ts
}
}
func (user *User) IsInSpace(ctx context.Context, portal PortalKey) bool {
func (user *User) IsInSpace(portal PortalKey) bool {
user.inSpaceCacheLock.Lock()
defer user.inSpaceCacheLock.Unlock()
if cached, ok := user.inSpaceCache[portal]; ok {
return cached
}
var inSpace bool
err := user.qh.GetDB().QueryRow(ctx, getIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver).Scan(&inSpace)
err := user.db.QueryRow("SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&inSpace)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to query in space status")
return false
user.log.Warnfln("Failed to scan in space status from user portal table: %v", err)
}
user.inSpaceCache[portal] = inSpace
return inSpace
}
func (user *User) MarkInSpace(ctx context.Context, portal PortalKey) {
func (user *User) MarkInSpace(portal PortalKey) {
user.inSpaceCacheLock.Lock()
defer user.inSpaceCacheLock.Unlock()
_, err := user.qh.GetDB().Exec(ctx, setIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver)
_, err := user.db.Exec(`
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
`, user.MXID, portal.JID, portal.Receiver)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Str("user_id", user.MXID.String()).
Any("portal_key", portal).
Msg("Failed to update in space status")
user.log.Warnfln("Failed to update in space status: %v", err)
} else {
user.inSpaceCache[portal] = true
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,11 +17,10 @@
package main
import (
"context"
"fmt"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
@ -29,74 +28,47 @@ import (
"maunium.net/go/mautrix-whatsapp/database"
)
func (portal *Portal) MarkDisappearing(ctx context.Context, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) {
func (portal *Portal) MarkDisappearing(txn dbutil.Execable, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) {
if expiresIn == 0 {
return
}
expiresAt := startsAt.Add(expiresIn)
msg := portal.bridge.DB.DisappearingMessage.NewWithValues(portal.MXID, eventID, expiresIn, expiresAt)
err := msg.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to insert disappearing message")
}
msg.Insert(txn)
if expiresAt.Before(time.Now().Add(1 * time.Hour)) {
go portal.sleepAndDelete(context.WithoutCancel(ctx), msg)
go portal.sleepAndDelete(msg)
}
}
func (br *WABridge) SleepAndDeleteUpcoming(ctx context.Context) {
msgs, err := br.DB.DisappearingMessage.GetUpcomingScheduled(ctx, 1*time.Hour)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get upcoming disappearing messages")
return
}
for _, msg := range msgs {
func (br *WABridge) SleepAndDeleteUpcoming() {
for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
portal := br.GetPortalByMXID(msg.RoomID)
if portal == nil {
err = msg.Delete(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("event_id", msg.EventID).
Msg("Failed to delete disappearing message row with no portal")
}
msg.Delete()
} else {
go portal.sleepAndDelete(ctx, msg)
go portal.sleepAndDelete(msg)
}
}
}
func (portal *Portal) sleepAndDelete(ctx context.Context, msg *database.DisappearingMessage) {
func (portal *Portal) sleepAndDelete(msg *database.DisappearingMessage) {
if _, alreadySleeping := portal.currentlySleepingToDelete.LoadOrStore(msg.EventID, true); alreadySleeping {
return
}
defer portal.currentlySleepingToDelete.Delete(msg.EventID)
log := zerolog.Ctx(ctx)
sleepTime := msg.ExpireAt.Sub(time.Now())
log.Debug().
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Dur("sleep_time", sleepTime).
Msg("Sleeping before making message disappear")
portal.log.Debugfln("Sleeping for %s to make %s disappear", sleepTime, msg.EventID)
time.Sleep(sleepTime)
_, err := portal.MainIntent().RedactEvent(ctx, msg.RoomID, msg.EventID, mautrix.ReqRedact{
_, err := portal.MainIntent().RedactEvent(msg.RoomID, msg.EventID, mautrix.ReqRedact{
Reason: "Message expired",
TxnID: fmt.Sprintf("mxwa_disappear_%s", msg.EventID),
})
if err != nil {
log.Err(err).
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Msg("Failed to make event disappear")
portal.log.Warnfln("Failed to make %s disappear: %v", msg.EventID, err)
} else {
log.Debug().
Stringer("room_id", portal.MXID).
Stringer("event_id", msg.EventID).
Msg("Disappeared event")
}
err = msg.Delete(ctx)
if err != nil {
log.Err(err).Msg("Failed to delete disapperaing message row in database after redacting event")
portal.log.Debugfln("Disappeared %s", msg.EventID)
}
msg.Delete()
}

View file

@ -17,14 +17,12 @@
package main
import (
"context"
"fmt"
"html"
"regexp"
"sort"
"strings"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/event"
@ -106,27 +104,22 @@ func NewFormatter(bridge *WABridge) *Formatter {
return formatter
}
func (formatter *Formatter) getMatrixInfoByJID(ctx context.Context, roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) {
func (formatter *Formatter) getMatrixInfoByJID(roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) {
if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
mxid = puppet.MXID
displayname = puppet.Displayname
}
if user := formatter.bridge.GetUserByJID(jid); user != nil {
mxid = user.MXID
member, err := formatter.bridge.StateStore.GetMember(ctx, roomID, user.MXID)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("room_id", roomID).
Stringer("user_id", user.MXID).
Msg("Failed to get member profile from state store")
} else if len(member.Displayname) > 0 {
member := formatter.bridge.StateStore.GetMember(roomID, user.MXID)
if len(member.Displayname) > 0 {
displayname = member.Displayname
}
}
return
}
func (formatter *Formatter) ParseWhatsApp(ctx context.Context, roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) {
func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) {
output := html.EscapeString(content.Body)
for regex, replacement := range formatter.waReplString {
output = regex.ReplaceAllString(output, replacement)
@ -152,7 +145,7 @@ func (formatter *Formatter) ParseWhatsApp(ctx context.Context, roomID id.RoomID,
// TODO lid support?
continue
}
mxid, displayname := formatter.getMatrixInfoByJID(ctx, roomID, jid)
mxid, displayname := formatter.getMatrixInfoByJID(roomID, jid)
number := "@" + jid.User
output = strings.ReplaceAll(output, number, fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname))
content.Body = strings.ReplaceAll(content.Body, number, displayname)

45
go.mod
View file

@ -1,25 +1,25 @@
module maunium.net/go/mautrix-whatsapp
go 1.21
go 1.20
require (
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.22
github.com/prometheus/client_golang v1.19.0
github.com/rs/zerolog v1.32.0
github.com/mattn/go-sqlite3 v1.14.19
github.com/prometheus/client_golang v1.17.0
github.com/rs/zerolog v1.31.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/tidwall/gjson v1.17.1
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e
github.com/tidwall/gjson v1.17.0
go.mau.fi/util v0.2.1
go.mau.fi/webp v0.1.0
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
golang.org/x/image v0.15.0
golang.org/x/net v0.22.0
google.golang.org/protobuf v1.33.0
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611
golang.org/x/image v0.14.0
golang.org/x/net v0.19.0
google.golang.org/protobuf v1.31.0
maunium.net/go/maulogger/v2 v2.4.1
maunium.net/go/mautrix v0.16.2
)
require (
@ -27,27 +27,24 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/yuin/goldmark v1.7.0 // indirect
github.com/yuin/goldmark v1.6.0 // indirect
go.mau.fi/libsignal v0.1.0 // indirect
go.mau.fi/zeroconfig v0.1.2 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/crypto v0.16.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
maunium.net/go/mauflag v1.0.0 // indirect
)
//replace go.mau.fi/util => ../../Go/go-util
//replace maunium.net/go/mautrix => ../mautrix-go

100
go.sum
View file

@ -1,9 +1,6 @@
filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek=
filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c h1:WqjRVgUO039eiISCjsZC4F9onOEV93DJAk6v33rsZzY=
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c/go.mod h1:b9FFm9y4mEm36G8ytVmS1vkNzJa0KepmcdVY+qf7qRU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
@ -12,18 +9,18 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@ -33,75 +30,78 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU=
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.mau.fi/libsignal v0.1.0 h1:vAKI/nJ5tMhdzke4cTK1fb0idJzz1JuEIpmjprueC+c=
go.mau.fi/libsignal v0.1.0/go.mod h1:R8ovrTezxtUNzCQE5PH30StOQWWeBskBsWE55vMfY9I=
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e h1:e1jDj/MjleSS5r9DMRbuCZYKy5Rr+sbsu8eWjtLqrGk=
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e/go.mod h1:jOAREC/go8T6rGic01cu6WRa90xi9U4z3QmDjRf8xpo=
go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw=
go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c=
go.mau.fi/webp v0.1.0 h1:BHObH/DcFntT9KYun5pDr0Ot4eUZO8k2C7eP7vF4ueA=
go.mau.fi/webp v0.1.0/go.mod h1:e42Z+VMFrUMS9cpEwGRIor+lQWO8oUAyPyMtcL+NMt8=
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462 h1:QOGjCIh2WEfkgX/38KLjnNof79GWx0T+KLrhTHiws3s=
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462/go.mod h1:lQHbhaG/fI+6hfGqz5Vzn2OBJBEZ05H0kCP6iJXriN4=
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735 h1:+teJYCOK6M4Kn2TYCj29levhHVwnJTmgCtEXLtgwQtM=
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735/go.mod h1:5xTtHNaZpGni6z6aE1iEopjW7wNgsKcolZxZrOujK9M=
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 h1:qCEDpW1G+vcj3Y7Fy52pEM1AWm3abj8WimGYejI3SC4=
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa h1:TLSWIAWKIWxLghgzWfp7o92pVCcFR3yLsArc0s/tsMs=
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa/go.mod h1:0sfLB2ejW+lhgio4UlZMmn5i9SuZ8mxFkonFSamrfTE=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg=
maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4=

View file

@ -17,7 +17,6 @@
package main
import (
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
@ -25,11 +24,11 @@ import (
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/variationselector"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/util/variationselector"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
@ -65,8 +64,9 @@ func (user *User) handleHistorySyncsLoop() {
if batchSend {
// Start the backfill queue.
user.BackfillQueue = &BackfillQueue{
BackfillQuery: user.bridge.DB.BackfillQueue,
BackfillQuery: user.bridge.DB.Backfill,
reCheckChannels: []chan bool{},
log: user.log.Sub("BackfillQueue"),
}
forwardAndImmediate := []database.BackfillType{database.BackfillImmediate, database.BackfillForward}
@ -109,113 +109,80 @@ func (user *User) handleHistorySyncsLoop() {
const EnqueueBackfillsDelay = 30 * time.Second
func (user *User) enqueueAllBackfills() {
log := user.zlog.With().
Str("method", "User.enqueueAllBackfills").
Logger()
ctx := log.WithContext(context.TODO())
nMostRecent, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations)
if err != nil {
log.Err(err).Msg("Failed to get recent history sync conversations from database")
return
} else if len(nMostRecent) == 0 {
return
}
log.Info().
Int("chat_count", len(nMostRecent)).
Msg("Enqueueing backfills for recent chats in history sync")
// Find the portals for all the conversations.
portals := make([]*Portal, 0, len(nMostRecent))
for _, conv := range nMostRecent {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
log.Err(err).Str("conversation_id", conv.ConversationID).Msg("Failed to parse chat JID in history sync")
continue
nMostRecent := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations)
if len(nMostRecent) > 0 {
user.log.Infofln("%v has passed since the last history sync blob, enqueueing backfills for %d chats", EnqueueBackfillsDelay, len(nMostRecent))
// Find the portals for all the conversations.
portals := []*Portal{}
for _, conv := range nMostRecent {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
user.log.Warnfln("Failed to parse chat JID '%s' in history sync: %v", conv.ConversationID, err)
continue
}
portals = append(portals, user.GetPortalByJID(jid))
}
portals = append(portals, user.GetPortalByJID(jid))
user.EnqueueImmediateBackfills(portals)
user.EnqueueForwardBackfills(portals)
user.EnqueueDeferredBackfills(portals)
// Tell the queue to check for new backfill requests.
user.BackfillQueue.ReCheck()
}
user.EnqueueImmediateBackfills(ctx, portals)
user.EnqueueForwardBackfills(ctx, portals)
user.EnqueueDeferredBackfills(ctx, portals)
// Tell the queue to check for new backfill requests.
user.BackfillQueue.ReCheck()
}
func (user *User) backfillAll() {
log := user.zlog.With().
Str("method", "User.backfillAll").
Logger()
ctx := log.WithContext(context.TODO())
conversations, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, -1)
if err != nil {
log.Err(err).Msg("Failed to get history sync conversations from database")
return
} else if len(conversations) == 0 {
return
}
log.Info().
Int("conversation_count", len(conversations)).
Msg("Probably received all history sync blobs, now backfilling conversations")
limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations
bridgedCount := 0
// Find the portals for all the conversations.
for _, conv := range conversations {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
log.Err(err).
Str("conversation_id", conv.ConversationID).
Msg("Failed to parse chat JID in history sync")
continue
}
portal := user.GetPortalByJID(jid)
if portal.MXID != "" {
log.Debug().
Str("portal_jid", portal.Key.JID.String()).
Msg("Chat already has a room, deleting messages from database")
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
conversations := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, -1)
if len(conversations) > 0 {
user.zlog.Info().
Int("conversation_count", len(conversations)).
Msg("Probably received all history sync blobs, now backfilling conversations")
limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations
bridgedCount := 0
// Find the portals for all the conversations.
for _, conv := range conversations {
jid, err := types.ParseJID(conv.ConversationID)
if err != nil {
log.Err(err).Str("portal_jid", portal.Key.JID.String()).
Msg("Failed to delete history sync conversation with existing portal from database")
user.zlog.Warn().Err(err).
Str("conversation_id", conv.ConversationID).
Msg("Failed to parse chat JID in history sync")
continue
}
bridgedCount++
} else if hasMessages, err := user.bridge.DB.HistorySync.ConversationHasMessages(ctx, user.MXID, portal.Key); err != nil {
log.Err(err).Str("portal_jid", portal.Key.JID.String()).Msg("Failed to check if chat has messages in history sync")
} else if !hasMessages {
log.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync")
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
if err != nil {
log.Err(err).Str("portal_jid", portal.Key.JID.String()).
Msg("Failed to delete history sync conversation with no messages from database")
}
} else if limit < 0 || bridgedCount < limit {
bridgedCount++
err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, true)
if err != nil {
log.Err(err).Msg("Failed to create Matrix room for backfill")
portal := user.GetPortalByJID(jid)
if portal.MXID != "" {
user.zlog.Debug().
Str("portal_jid", portal.Key.JID.String()).
Msg("Chat already has a room, deleting messages from database")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
bridgedCount++
} else if !user.bridge.DB.HistorySync.ConversationHasMessages(user.MXID, portal.Key) {
user.zlog.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
} else if limit < 0 || bridgedCount < limit {
bridgedCount++
err = portal.CreateMatrixRoom(user, nil, nil, true, true)
if err != nil {
user.zlog.Err(err).Msg("Failed to create Matrix room for backfill")
}
}
}
}
}
func (portal *Portal) legacyBackfill(ctx context.Context, user *User) {
func (portal *Portal) legacyBackfill(user *User) {
defer portal.latestEventBackfillLock.Unlock()
// This should only be called from CreateMatrixRoom which locks latestEventBackfillLock before creating the room.
if portal.latestEventBackfillLock.TryLock() {
panic("legacyBackfill() called without locking latestEventBackfillLock")
}
log := zerolog.Ctx(ctx).With().Str("action", "legacy backfill").Logger()
ctx = log.WithContext(ctx)
conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get history sync conversation data for backfill")
return
}
messages, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount)
if err != nil {
log.Err(err).Msg("Failed to get history sync messages for backfill")
return
}
// TODO use portal.zlog instead of user.zlog
log := user.zlog.With().
Str("portal_jid", portal.Key.JID.String()).
Str("action", "legacy backfill").
Logger()
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, portal.Key)
messages := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount)
log.Debug().Int("message_count", len(messages)).Msg("Got messages to backfill from database")
for i := len(messages) - 1; i >= 0; i-- {
msgEvt, err := user.Client.ParseWebMessage(portal.Key.JID, messages[i])
@ -227,26 +194,18 @@ func (portal *Portal) legacyBackfill(ctx context.Context, user *User) {
Msg("Dropping historical message due to parse error")
continue
}
ctx := log.With().
Str("message_id", msgEvt.Info.ID).
Stringer("message_sender", msgEvt.Info.Sender).
Logger().
WithContext(ctx)
portal.handleMessage(ctx, user, msgEvt, true)
portal.handleMessage(user, msgEvt, true)
}
if conv != nil {
isUnread := conv.MarkedAsUnread || conv.UnreadCount > 0
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
shouldMarkAsRead := !isUnread || isTooOld
if shouldMarkAsRead {
user.markSelfReadFull(ctx, portal)
user.markSelfReadFull(portal)
}
}
log.Info().Msg("Backfill complete, deleting leftover messages from database")
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
if err != nil {
log.Err(err).Msg("Failed to delete history sync conversation from database after backfill")
}
log.Debug().Msg("Backfill complete, deleting leftover messages from database")
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
}
func (user *User) dailyMediaRequestLoop() {
@ -265,49 +224,29 @@ func (user *User) dailyMediaRequestLoop() {
if requestStartTime.Before(now) {
requestStartTime = requestStartTime.AddDate(0, 0, 1)
}
log := user.zlog.With().
Str("action", "daily media request loop").
Logger()
ctx := log.WithContext(context.Background())
// Wait to start the loop
log.Info().Time("start_loop_at", requestStartTime).Msg("Waiting until start time to do media retry requests")
user.log.Infof("Waiting until %s to do media retry requests", requestStartTime)
time.Sleep(time.Until(requestStartTime))
for {
mediaBackfillRequests, err := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(ctx, user.MXID)
if err != nil {
log.Err(err).Msg("Failed to get media retry requests")
} else if len(mediaBackfillRequests) > 0 {
log.Info().Int("media_request_count", len(mediaBackfillRequests)).Msg("Sending media retry requests")
mediaBackfillRequests := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(user.MXID)
user.log.Infof("Sending %d media retry requests", len(mediaBackfillRequests))
// Send all the media backfill requests for the user at once
for _, req := range mediaBackfillRequests {
portal := user.GetPortalByJID(req.PortalKey.JID)
_, err = portal.requestMediaRetry(ctx, user, req.EventID, req.MediaKey)
if err != nil {
log.Err(err).
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Failed to send media retry request")
req.Status = database.MediaBackfillRequestStatusRequestFailed
req.Error = err.Error()
} else {
log.Debug().
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Sent media retry request")
req.Status = database.MediaBackfillRequestStatusRequested
}
req.MediaKey = nil
err = req.Upsert(ctx)
if err != nil {
log.Err(err).
Stringer("portal_key", req.PortalKey).
Stringer("event_id", req.EventID).
Msg("Failed to save status of media retry request")
}
// Send all of the media backfill requests for the user at once
for _, req := range mediaBackfillRequests {
portal := user.GetPortalByJID(req.PortalKey.JID)
_, err := portal.requestMediaRetry(user, req.EventID, req.MediaKey)
if err != nil {
user.log.Warnf("Failed to send media retry request for %s / %s", req.PortalKey.String(), req.EventID)
req.Status = database.MediaBackfillRequestStatusRequestFailed
req.Error = err.Error()
} else {
user.log.Debugfln("Sent media retry request for %s / %s", req.PortalKey.String(), req.EventID)
req.Status = database.MediaBackfillRequestStatusRequested
}
req.MediaKey = nil
req.Upsert()
}
// Wait for 24 hours before making requests again
@ -315,29 +254,20 @@ func (user *User) dailyMediaRequestLoop() {
}
}
func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTask, conv *database.HistorySyncConversation, portal *Portal) {
func (user *User) backfillInChunks(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) {
portal.backfillLock.Lock()
defer portal.backfillLock.Unlock()
log := zerolog.Ctx(ctx)
if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(ctx, portal.MXID, user.MXID) {
portal.ensureUserInvited(ctx, user)
if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(portal.MXID, user.MXID) {
portal.ensureUserInvited(user)
}
backfillState, err := user.bridge.DB.BackfillState.GetBackfillState(ctx, user.MXID, portal.Key)
backfillState := user.bridge.DB.Backfill.GetBackfillState(user.MXID, &portal.Key)
if backfillState == nil {
backfillState = user.bridge.DB.BackfillState.NewBackfillState(user.MXID, portal.Key)
backfillState = user.bridge.DB.Backfill.NewBackfillState(user.MXID, &portal.Key)
}
err = backfillState.SetProcessingBatch(ctx, true)
if err != nil {
log.Err(err).Msg("Failed to mark batch as being processed")
}
defer func() {
err = backfillState.SetProcessingBatch(ctx, false)
if err != nil {
log.Err(err).Msg("Failed to mark batch as no longer being processed")
}
}()
backfillState.SetProcessingBatch(true)
defer backfillState.SetProcessingBatch(false)
var timeEnd *time.Time
var forward, shouldMarkAsRead bool
@ -345,27 +275,17 @@ func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTa
if req.BackfillType == database.BackfillForward {
// TODO this overrides the TimeStart set when enqueuing the backfill
// maybe the enqueue should instead include the prev event ID
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get newest message in chat")
return
}
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
start := lastMessage.Timestamp.Add(1 * time.Second)
req.TimeStart = &start
// Sending events at the end of the room (= latest events)
forward = true
} else {
firstMessage, err := portal.bridge.DB.Message.GetFirstInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get oldest message in chat")
return
}
firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key)
if firstMessage != nil {
end := firstMessage.Timestamp.Add(-1 * time.Second)
timeEnd = &end
log.Debug().
Time("oldest_message_ts", firstMessage.Timestamp).
Msg("Limiting backfill to messages older than oldest message")
user.log.Debugfln("Limiting backfill to end at %v", end)
} else {
// Portal is empty -> events are latest
forward = true
@ -383,48 +303,45 @@ func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTa
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
shouldMarkAsRead = !isUnread || isTooOld
}
allMsgs, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
sendDisappearedNotice := false
// If expired messages are on, and a notice has not been sent to this chat
// about it having disappeared messages at the conversation timestamp, send
// a notice indicating so.
if len(allMsgs) == 0 && conv.EphemeralExpiration != nil && *conv.EphemeralExpiration > 0 {
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
log.Err(err).Msg("Failed to get last message in chat to check if disappeared notice should be sent")
}
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
if lastMessage == nil || conv.LastMessageTimestamp.After(lastMessage.Timestamp) {
sendDisappearedNotice = true
}
}
if !sendDisappearedNotice && len(allMsgs) == 0 {
log.Debug().Msg("Not backfilling chat: no bridgeable messages found")
user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID)
return
}
if len(portal.MXID) == 0 {
log.Debug().Msg("Creating portal for chat as part of history sync handling")
err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, false)
user.log.Debugln("Creating portal for", portal.Key.JID, "as part of history sync handling")
err := portal.CreateMatrixRoom(user, nil, nil, true, false)
if err != nil {
log.Err(err).Msg("Failed to create room for chat during backfill")
user.log.Errorfln("Failed to create room for %s during backfill: %v", portal.Key.JID, err)
return
}
}
// Update the backfill status here after the room has been created.
portal.updateBackfillStatus(ctx, backfillState)
portal.updateBackfillStatus(backfillState)
if sendDisappearedNotice {
log.Debug().Time("last_message_time", conv.LastMessageTimestamp).
Msg("Sending notice that there are disappeared messages in the chat")
resp, err := portal.sendMessage(ctx, portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
user.log.Debugfln("Sending notice to %s that there are disappeared messages ending at %v", portal.Key.JID, conv.LastMessageTimestamp)
resp, err := portal.sendMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
MsgType: event.MsgNotice,
Body: portal.formatDisappearingMessageNotice(),
}, nil, conv.LastMessageTimestamp.UnixMilli())
if err != nil {
log.Err(err).Msg("Failed to send disappeared messages notice event")
portal.log.Errorln("Error sending disappearing messages notice event")
return
}
@ -436,18 +353,12 @@ func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTa
msg.SenderMXID = portal.MainIntent().UserID
msg.Sent = true
msg.Type = database.MsgFake
err = msg.Insert(ctx)
if err != nil {
log.Err(err).Msg("Failed to save fake message entry for disappearing message timer in backfill")
}
user.markSelfReadFull(ctx, portal)
msg.Insert(nil)
user.markSelfReadFull(portal)
return
}
log.Info().
Int("message_count", len(allMsgs)).
Int("max_batch_events", req.MaxBatchEvents).
Msg("Backfilling messages")
user.log.Infofln("Backfilling %d messages in %s, %d messages at a time (queue ID: %d)", len(allMsgs), portal.Key.JID, req.MaxBatchEvents, req.QueueID)
toBackfill := allMsgs[0:]
for len(toBackfill) > 0 {
var msgs []*waProto.WebMessageInfo
@ -461,14 +372,14 @@ func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTa
if len(msgs) > 0 {
time.Sleep(time.Duration(req.BatchDelay) * time.Second)
log.Debug().Int("batch_message_count", len(msgs)).Msg("Backfilling message batch")
portal.backfill(ctx, user, msgs, forward, shouldMarkAsRead)
user.log.Debugfln("Backfilling %d messages in %s (queue ID: %d)", len(msgs), portal.Key.JID, req.QueueID)
portal.backfill(user, msgs, forward, shouldMarkAsRead)
}
}
log.Debug().Int("message_count", len(allMsgs)).Msg("Finished backfilling messages in queue entry")
err = user.bridge.DB.HistorySync.DeleteMessages(ctx, user.MXID, conv.ConversationID, allMsgs)
user.log.Debugfln("Finished backfilling %d messages in %s (queue ID: %d)", len(allMsgs), portal.Key.JID, req.QueueID)
err := user.bridge.DB.HistorySync.DeleteMessages(user.MXID, conv.ConversationID, allMsgs)
if err != nil {
log.Err(err).Msg("Failed to delete history sync messages after backfilling")
user.log.Warnfln("Failed to delete %d history sync messages after backfilling (queue ID: %d): %v", len(allMsgs), req.QueueID, err)
}
if req.TimeStart == nil {
@ -488,11 +399,8 @@ func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTa
// beginning of time.
backfillState.FirstExpectedTimestamp = 0
}
err = backfillState.Upsert(ctx)
if err != nil {
log.Err(err).Msg("Failed to mark backfill state as completed in database")
}
portal.updateBackfillStatus(ctx, backfillState)
backfillState.Upsert()
portal.updateBackfillStatus(backfillState)
}
}
@ -500,13 +408,13 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
if evt == nil || evt.SyncType == nil {
return
}
log := user.zlog.With().
log := user.bridge.ZLog.With().
Str("method", "User.storeHistorySync").
Str("user_id", user.MXID.String()).
Str("sync_type", evt.GetSyncType().String()).
Uint32("chunk_order", evt.GetChunkOrder()).
Uint32("progress", evt.GetProgress()).
Logger()
ctx := log.WithContext(context.TODO())
if evt.GetGlobalSettings() != nil {
log.Debug().Interface("global_settings", evt.GetGlobalSettings()).Msg("Got global settings in history sync")
}
@ -558,7 +466,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
historySyncConversation := user.bridge.DB.HistorySync.NewConversationWithValues(
user.MXID,
conv.GetId(),
portal.Key,
&portal.Key,
getConversationTimestamp(conv),
conv.GetMuteEndTime(),
conv.GetArchived(),
@ -568,10 +476,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
conv.EphemeralExpiration,
conv.GetMarkedAsUnread(),
conv.GetUnreadCount())
err := historySyncConversation.Upsert(ctx)
if err != nil {
log.Err(err).Msg("Failed to insert history sync conversation into database")
}
historySyncConversation.Upsert()
}
var minTime, maxTime time.Time
@ -616,7 +521,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
Msg("Failed to save historical message")
continue
}
err = message.Insert(ctx)
err = message.Insert()
if err != nil {
log.Error().Err(err).
Int("msg_index", i).
@ -665,20 +570,15 @@ func getConversationTimestamp(conv *waProto.Conversation) uint64 {
return convTs
}
func (user *User) EnqueueImmediateBackfills(ctx context.Context, portals []*Portal) {
func (user *User) EnqueueImmediateBackfills(portals []*Portal) {
for priority, portal := range portals {
maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents
initialBackfill := user.bridge.DB.BackfillQueue.NewWithValues(user.MXID, database.BackfillImmediate, priority, portal.Key, nil, maxMessages, maxMessages, 0)
err := initialBackfill.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert immediate backfill into database")
}
initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, maxMessages, maxMessages, 0)
initialBackfill.Insert()
}
}
func (user *User) EnqueueDeferredBackfills(ctx context.Context, portals []*Portal) {
func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
numPortals := len(portals)
for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Deferred {
for portalIdx, portal := range portals {
@ -687,36 +587,22 @@ func (user *User) EnqueueDeferredBackfills(ctx context.Context, portals []*Porta
startDaysAgo := time.Now().AddDate(0, 0, -backfillStage.StartDaysAgo)
startDate = &startDaysAgo
}
backfillMessages := user.bridge.DB.BackfillQueue.NewWithValues(
user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
err := backfillMessages.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert deferred backfill into database")
}
backfillMessages := user.bridge.DB.Backfill.NewWithValues(
user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
backfillMessages.Insert()
}
}
}
func (user *User) EnqueueForwardBackfills(ctx context.Context, portals []*Portal) {
func (user *User) EnqueueForwardBackfills(portals []*Portal) {
for priority, portal := range portals {
lastMsg, err := user.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to get last message in chat to enqueue forward backfill")
} else if lastMsg == nil {
lastMsg := user.bridge.DB.Message.GetLastInChat(portal.Key)
if lastMsg == nil {
continue
}
backfill := user.bridge.DB.BackfillQueue.NewWithValues(
user.MXID, database.BackfillForward, priority, portal.Key, &lastMsg.Timestamp, -1, -1, 0)
err = backfill.Insert(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("portal_key", portal.Key).
Msg("Failed to insert forward backfill into database")
}
backfill := user.bridge.DB.Backfill.NewWithValues(
user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, -1, -1, 0)
backfill.Insert()
}
}
@ -733,11 +619,12 @@ func (portal *Portal) deterministicEventID(sender types.JID, messageID types.Mes
}
var (
PortalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType}
BackfillStatusEvent = event.Type{Type: "com.beeper.backfill_status", Class: event.StateEventType}
)
func (portal *Portal) backfill(ctx context.Context, source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) {
log := zerolog.Ctx(ctx)
func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) *mautrix.RespBeeperBatchSend {
var req mautrix.ReqBeeperBatchSend
var infos []*wrappedInfo
@ -746,10 +633,7 @@ func (portal *Portal) backfill(ctx context.Context, source *User, messages []*wa
req.MarkReadBy = source.MXID
}
log.Info().
Bool("forward", isForward).
Int("message_count", len(messages)).
Msg("Processing history sync message batch")
portal.log.Infofln("Processing history sync with %d messages (forward: %t)", len(messages), isForward)
// The messages are ordered newest to oldest, so iterate them in reverse order.
for i := len(messages) - 1; i >= 0; i-- {
webMsg := messages[i]
@ -757,16 +641,11 @@ func (portal *Portal) backfill(ctx context.Context, source *User, messages []*wa
if err != nil {
continue
}
log := log.With().
Str("message_id", msgEvt.Info.ID).
Stringer("message_sender", msgEvt.Info.Sender).
Logger()
ctx := log.WithContext(ctx)
msgType := getMessageType(msgEvt.Message)
if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" {
if msgType != "ignore" {
log.Debug().Msg("Skipping message with unknown type in backfill")
portal.log.Debugfln("Skipping message %s with unknown type in backfill", msgEvt.Info.ID)
}
continue
}
@ -775,83 +654,85 @@ func (portal *Portal) backfill(ctx context.Context, source *User, messages []*wa
if !existingContact.Found || existingContact.PushName == "" {
changed, _, err := source.Client.Store.Contacts.PutPushName(msgEvt.Info.Sender, webMsg.GetPushName())
if err != nil {
log.Err(err).Msg("Failed to save push name from historical message to device store")
source.log.Errorfln("Failed to save push name of %s from historical message in device store: %v", msgEvt.Info.Sender, err)
} else if changed {
log.Debug().Str("push_name", webMsg.GetPushName()).Msg("Got push name from historical message")
source.log.Debugfln("Got push name %s for %s from historical message", webMsg.GetPushName(), msgEvt.Info.Sender)
}
}
}
puppet := portal.getMessagePuppet(ctx, source, &msgEvt.Info)
puppet := portal.getMessagePuppet(source, &msgEvt.Info)
if puppet == nil {
continue
}
converted := portal.convertMessage(ctx, puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true)
converted := portal.convertMessage(puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true)
if converted == nil {
log.Debug().Msg("Skipping unsupported message in backfill")
portal.log.Debugfln("Skipping unsupported message %s in backfill", msgEvt.Info.ID)
continue
}
if converted.ReplyTo != nil {
portal.SetReply(ctx, converted.Content, converted.ReplyTo, true)
portal.SetReply(msgEvt.Info.ID, converted.Content, converted.ReplyTo, true)
}
err = portal.appendBatchEvents(ctx, source, converted, &msgEvt.Info, webMsg, &req.Events, &infos)
err = portal.appendBatchEvents(source, converted, &msgEvt.Info, webMsg, &req.Events, &infos)
if err != nil {
log.Err(err).Msg("Failed to handle message in backfill")
portal.log.Errorfln("Error handling message %s during backfill: %v", msgEvt.Info.ID, err)
}
}
log.Info().Int("event_count", len(req.Events)).Msg("Made Matrix events from messages in batch")
portal.log.Infofln("Made %d Matrix events from messages in batch", len(req.Events))
if len(req.Events) == 0 {
return
return nil
}
resp, err := portal.MainIntent().BeeperBatchSend(ctx, portal.MXID, &req)
resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &req)
if err != nil {
log.Err(err).Msg("Failed to send batch of messages")
return
}
err = portal.bridge.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
return portal.finishBatch(ctx, resp.EventIDs, infos)
})
if err != nil {
log.Err(err).Msg("Failed to save message batch to database")
return
}
log.Info().Msg("Successfully sent backfill batch")
if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
go portal.requestMediaRetries(context.TODO(), source, resp.EventIDs, infos)
portal.log.Errorln("Error batch sending messages:", err)
return nil
} else {
txn, err := portal.bridge.DB.Begin()
if err != nil {
portal.log.Errorln("Failed to start transaction to save batch messages:", err)
return nil
}
portal.finishBatch(txn, resp.EventIDs, infos)
err = txn.Commit()
if err != nil {
portal.log.Errorln("Failed to commit transaction to save batch messages:", err)
return nil
}
if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
go portal.requestMediaRetries(source, resp.EventIDs, infos)
}
return resp
}
}
func (portal *Portal) requestMediaRetries(ctx context.Context, source *User, eventIDs []id.EventID, infos []*wrappedInfo) {
func (portal *Portal) requestMediaRetries(source *User, eventIDs []id.EventID, infos []*wrappedInfo) {
for i, info := range infos {
if info != nil && info.Error == database.MsgErrMediaNotFound && info.MediaKey != nil {
switch portal.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod {
case config.MediaRequestMethodImmediate:
err := source.Client.SendMediaRetryReceipt(info.MessageInfo, info.MediaKey)
if err != nil {
portal.zlog.Err(err).Str("message_id", info.ID).Msg("Failed to send post-backfill media retry request")
portal.log.Warnfln("Failed to send post-backfill media retry request for %s: %v", info.ID, err)
} else {
portal.zlog.Debug().Str("message_id", info.ID).Msg("Sent post-backfill media retry request")
portal.log.Debugfln("Sent post-backfill media retry request for %s", info.ID)
}
case config.MediaRequestMethodLocalTime:
req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, portal.Key, eventIDs[i], info.MediaKey)
err := req.Upsert(ctx)
if err != nil {
portal.zlog.Err(err).
Stringer("event_id", eventIDs[i]).
Msg("Failed to upsert media backfill request")
}
req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, &portal.Key, eventIDs[i], info.MediaKey)
req.Upsert()
}
}
}
}
func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
if portal.bridge.Config.Bridge.CaptionInMessage {
converted.MergeCaption()
}
mainEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
if err != nil {
return err
}
@ -869,7 +750,7 @@ func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, conve
ExpiresIn: converted.ExpiresIn,
}
if converted.Caption != nil {
captionEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
captionEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
if err != nil {
return err
}
@ -881,7 +762,7 @@ func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, conve
}
if converted.MultiEvent != nil {
for i, subEvtContent := range converted.MultiEvent {
subEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
subEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
if err != nil {
return err
}
@ -890,7 +771,7 @@ func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, conve
}
}
for _, reaction := range raw.GetReactions() {
reactionEvent, reactionInfo := portal.wrapBatchReaction(ctx, source, reaction, mainEvt.ID, info.Timestamp)
reactionEvent, reactionInfo := portal.wrapBatchReaction(source, reaction, mainEvt.ID, info.Timestamp)
if reactionEvent != nil {
*eventsArray = append(*eventsArray, reactionEvent)
*infoArray = append(*infoArray, &wrappedInfo{
@ -904,7 +785,7 @@ func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, conve
return nil
}
func (portal *Portal) wrapBatchReaction(ctx context.Context, source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) {
func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) {
var senderJID types.JID
if reaction.GetKey().GetFromMe() {
senderJID = source.JID.ToNonAD()
@ -926,7 +807,7 @@ func (portal *Portal) wrapBatchReaction(ctx context.Context, source *User, react
ID: reaction.GetKey().GetId(),
Timestamp: mainEventTS,
}
puppet := portal.getMessagePuppet(ctx, source, reactionInfo)
puppet := portal.getMessagePuppet(source, reactionInfo)
if puppet == nil {
return
}
@ -953,12 +834,12 @@ func (portal *Portal) wrapBatchReaction(ctx context.Context, source *User, react
return
}
func (portal *Portal) wrapBatchEvent(ctx context.Context, info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
wrappedContent := event.Content{
Parsed: content,
Raw: extraContent,
}
newEventType, err := portal.encrypt(ctx, intent, &wrappedContent, eventType)
newEventType, err := portal.encrypt(intent, &wrappedContent, eventType)
if err != nil {
return nil, err
}
@ -972,37 +853,37 @@ func (portal *Portal) wrapBatchEvent(ctx context.Context, info *types.MessageInf
}, nil
}
func (portal *Portal) finishBatch(ctx context.Context, eventIDs []id.EventID, infos []*wrappedInfo) error {
func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID, infos []*wrappedInfo) {
for i, info := range infos {
if info == nil {
continue
}
eventID := eventIDs[i]
portal.markHandled(ctx, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error)
portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error)
if info.Type == database.MsgReaction {
portal.upsertReaction(ctx, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
}
if info.ExpiresIn > 0 {
portal.MarkDisappearing(ctx, eventID, info.ExpiresIn, info.ExpirationStart)
portal.MarkDisappearing(txn, eventID, info.ExpiresIn, info.ExpirationStart)
}
}
return nil
portal.log.Infofln("Successfully sent %d events", len(eventIDs))
}
func (portal *Portal) updateBackfillStatus(ctx context.Context, backfillState *database.BackfillState) {
func (portal *Portal) updateBackfillStatus(backfillState *database.BackfillState) {
backfillStatus := "backfilling"
if backfillState.BackfillComplete {
backfillStatus = "complete"
}
_, err := portal.bridge.Bot.SendStateEvent(ctx, portal.MXID, BackfillStatusEvent, "", map[string]interface{}{
_, err := portal.bridge.Bot.SendStateEvent(portal.MXID, BackfillStatusEvent, "", map[string]interface{}{
"status": backfillStatus,
"first_timestamp": backfillState.FirstExpectedTimestamp * 1000,
})
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill status event to room")
portal.log.Errorln("Error sending backfill status event:", err)
}
}

62
main.go
View file

@ -17,7 +17,6 @@
package main
import (
"context"
_ "embed"
"net/http"
"net/url"
@ -27,18 +26,15 @@ import (
"sync"
"time"
"github.com/rs/zerolog"
waLog "go.mau.fi/whatsmeow/util/log"
"google.golang.org/protobuf/proto"
"go.mau.fi/util/configupgrade"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/util/configupgrade"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/bridge/status"
@ -95,7 +91,7 @@ func (br *WABridge) Init() {
br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
Analytics.log = br.ZLog.With().Str("component", "analytics").Logger()
Analytics.log = br.Log.Sub("Analytics")
Analytics.url = (&url.URL{
Scheme: "https",
Host: br.Config.Analytics.Host,
@ -104,20 +100,23 @@ func (br *WABridge) Init() {
Analytics.key = br.Config.Analytics.Token
Analytics.userID = br.Config.Analytics.UserID
if Analytics.IsEnabled() {
Analytics.log.Info().Str("override_user_id", Analytics.userID).Msg("Analytics metrics are enabled")
Analytics.log.Infoln("Analytics metrics are enabled")
if Analytics.userID != "" {
Analytics.log.Infoln("Overriding analytics user_id with %v", Analytics.userID)
}
}
br.DB = database.New(br.Bridge.DB)
br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), waLog.Zerolog(br.ZLog.With().Str("db_section", "whatsmeow").Logger()))
br.DB = database.New(br.Bridge.DB, br.Log.Sub("Database"))
br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), &waLogger{br.Log.Sub("Database").Sub("WhatsApp")})
br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
ss := br.Config.Bridge.Provisioning.SharedSecret
if len(ss) > 0 && ss != "disable" {
br.Provisioning = &ProvisioningAPI{bridge: br, log: br.ZLog.With().Str("component", "provisioning").Logger()}
br.Provisioning = &ProvisioningAPI{bridge: br}
}
br.Formatter = NewFormatter(br)
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.ZLog.With().Str("component", "metrics").Logger(), br.DB)
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent
store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
@ -149,10 +148,11 @@ func (br *WABridge) Init() {
func (br *WABridge) Start() {
err := br.WAContainer.Upgrade()
if err != nil {
br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to upgrade whatsmeow database")
br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err)
os.Exit(15)
}
if br.Provisioning != nil {
br.Log.Debugln("Initializing provisioning API")
br.Provisioning.Init()
}
go br.CheckWhatsAppUpdate()
@ -166,40 +166,30 @@ func (br *WABridge) Start() {
}
func (br *WABridge) CheckWhatsAppUpdate() {
br.ZLog.Debug().Msg("Checking for WhatsApp web update")
br.Log.Debugfln("Checking for WhatsApp web update")
resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
if err != nil {
br.ZLog.Warn().Err(err).Msg("Failed to check for WhatsApp web update")
br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
return
}
if store.GetWAVersion() == resp.ParsedVersion {
br.ZLog.Debug().Msg("Bridge is using latest WhatsApp web protocol")
br.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
} else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
if resp.IsBelowHard || resp.IsBroken {
br.ZLog.Warn().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore")
br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
} else if resp.IsBelowSoft {
br.ZLog.Info().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol")
br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
} else {
br.ZLog.Debug().
Stringer("latest_version", resp.ParsedVersion).
Stringer("current_version", store.GetWAVersion()).
Msg("Bridge is using outdated WhatsApp web protocol")
br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
}
} else {
br.ZLog.Debug().Msg("Bridge is using newer than latest WhatsApp web protocol")
br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
}
}
func (br *WABridge) Loop() {
ctx := br.ZLog.With().Str("action", "background loop").Logger().WithContext(context.TODO())
for {
br.SleepAndDeleteUpcoming(ctx)
br.SleepAndDeleteUpcoming()
time.Sleep(1 * time.Hour)
br.WarnUsersAboutDisconnection()
}
@ -209,14 +199,14 @@ func (br *WABridge) WarnUsersAboutDisconnection() {
br.usersLock.Lock()
for _, user := range br.usersByUsername {
if user.IsConnected() && !user.PhoneRecentlySeen(true) {
go user.sendPhoneOfflineWarning(context.TODO())
go user.sendPhoneOfflineWarning()
}
}
br.usersLock.Unlock()
}
func (br *WABridge) StartUsers() {
br.ZLog.Debug().Msg("Starting users")
br.Log.Debugln("Starting users")
foundAnySessions := false
for _, user := range br.GetAllUsers() {
if !user.JID.IsEmpty() {
@ -227,13 +217,13 @@ func (br *WABridge) StartUsers() {
if !foundAnySessions {
br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}.Fill(nil))
}
br.ZLog.Debug().Msg("Starting custom puppets")
br.Log.Debugln("Starting custom puppets")
for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
go func(puppet *Puppet) {
puppet.zlog.Debug().Stringer("custom_mxid", puppet.CustomMXID).Msg("Starting double puppet")
puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
err := puppet.StartCustomMXID(true)
if err != nil {
puppet.zlog.Err(err).Stringer("custom_mxid", puppet.CustomMXID).Msg("Failed to start double puppet")
puppet.log.Errorln("Failed to start custom puppet:", err)
}
}(loopuppet)
}
@ -245,7 +235,7 @@ func (br *WABridge) Stop() {
if user.Client == nil {
continue
}
user.zlog.Debug().Msg("Disconnecting user")
br.Log.Debugln("Disconnecting", user.MXID)
user.Client.Disconnect()
close(user.historySyncs)
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -17,10 +17,8 @@
package main
import (
"context"
"fmt"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix"
@ -37,89 +35,77 @@ func (br *WABridge) CreatePrivatePortal(roomID id.RoomID, brInviter bridge.User,
puppet := brGhost.(*Puppet)
key := database.NewPortalKey(puppet.JID, inviter.JID)
portal := br.GetPortalByJID(key)
log := br.ZLog.With().
Str("action", "create private portal").
Stringer("target_room_id", roomID).
Stringer("inviter_mxid", inviter.MXID).
Stringer("invitee_jid", puppet.JID).
Logger()
ctx := log.WithContext(context.TODO())
if len(portal.MXID) == 0 {
br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal)
return
}
ok := portal.ensureUserInvited(ctx, inviter)
ok := portal.ensureUserInvited(inviter)
if !ok {
log.Warn().Msg("Failed to invite user to existing private chat portal. Redirecting portal to new room...")
br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
br.Log.Warnfln("Failed to invite %s to existing private chat portal %s with %s. Redirecting portal to new room...", inviter.MXID, portal.MXID, puppet.JID)
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal)
return
}
intent := puppet.DefaultIntent()
errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Config.Homeserver.Domain).MatrixToURL())
errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%[1]s](https://matrix.to/#/%[1]s)", portal.MXID)
errorContent := format.RenderMarkdown(errorMessage, true, false)
_, _ = intent.SendMessageEvent(ctx, roomID, event.EventMessage, errorContent)
log.Debug().Msg("Leaving private chat room from invite as we already have chat with the user")
_, _ = intent.LeaveRoom(ctx, roomID)
_, _ = intent.SendMessageEvent(roomID, event.EventMessage, errorContent)
br.Log.Debugfln("Leaving private chat room %s as %s after accepting invite from %s as we already have chat with the user", roomID, puppet.MXID, inviter.MXID)
_, _ = intent.LeaveRoom(roomID)
}
func (br *WABridge) createPrivatePortalFromInvite(ctx context.Context, roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) {
log := zerolog.Ctx(ctx)
func (br *WABridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) {
// TODO check if room is already encrypted
var existingEncryption event.EncryptionEventContent
var encryptionEnabled bool
err := portal.MainIntent().StateEvent(ctx, roomID, event.StateEncryption, "", &existingEncryption)
err := portal.MainIntent().StateEvent(roomID, event.StateEncryption, "", &existingEncryption)
if err != nil {
log.Err(err).Msg("Failed to check if encryption is enabled")
portal.log.Warnfln("Failed to check if encryption is enabled in private chat room %s", roomID)
} else {
encryptionEnabled = existingEncryption.Algorithm == id.AlgorithmMegolmV1
}
portal.MXID = roomID
portal.updateLogger()
portal.Topic = PrivateChatTopic
portal.Name = puppet.Displayname
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
log.Info().Msg("Created private chat portal from invite")
portal.log.Infofln("Created private chat portal in %s after invite from %s", roomID, inviter.MXID)
intent := puppet.DefaultIntent()
if br.Config.Bridge.Encryption.Default || encryptionEnabled {
_, err = intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID})
_, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID})
if err != nil {
log.Err(err).Msg("Failed to invite bridge bot to enable e2be")
portal.log.Warnln("Failed to invite bridge bot to enable e2be:", err)
}
err = br.Bot.EnsureJoined(ctx, roomID)
err = br.Bot.EnsureJoined(roomID)
if err != nil {
log.Err(err).Msg("Failed to join as bridge bot to enable e2be")
portal.log.Warnln("Failed to join as bridge bot to enable e2be:", err)
}
if !encryptionEnabled {
_, err = intent.SendStateEvent(ctx, roomID, event.StateEncryption, "", portal.GetEncryptionEventContent())
_, err = intent.SendStateEvent(roomID, event.StateEncryption, "", portal.GetEncryptionEventContent())
if err != nil {
log.Err(err).Msg("Failed to enable e2be")
portal.log.Warnln("Failed to enable e2be:", err)
}
}
br.AS.StateStore.SetMembership(ctx, roomID, inviter.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, puppet.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, br.Bot.UserID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, inviter.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, puppet.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, br.Bot.UserID, event.MembershipJoin)
portal.Encrypted = true
}
_, _ = portal.MainIntent().SetRoomTopic(ctx, portal.MXID, portal.Topic)
_, _ = portal.MainIntent().SetRoomTopic(portal.MXID, portal.Topic)
if portal.shouldSetDMRoomMetadata() {
_, err = portal.MainIntent().SetRoomName(ctx, portal.MXID, portal.Name)
_, err = portal.MainIntent().SetRoomName(portal.MXID, portal.Name)
portal.NameSet = err == nil
_, err = portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, portal.AvatarURL)
_, err = portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL)
portal.AvatarSet = err == nil
}
err = portal.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal to database after creating from invite")
}
portal.UpdateBridgeInfo(ctx)
_, _ = intent.SendNotice(ctx, roomID, "Private chat portal created")
portal.Update(nil)
portal.UpdateBridgeInfo()
_, _ = intent.SendNotice(roomID, "Private chat portal created")
}
func (br *WABridge) HandlePresence(ctx context.Context, evt *event.Event) {
func (br *WABridge) HandlePresence(evt *event.Event) {
user := br.GetUserByMXIDIfExists(evt.Sender)
if user == nil || !user.IsLoggedIn() {
return
@ -133,15 +119,15 @@ func (br *WABridge) HandlePresence(ctx context.Context, evt *event.Event) {
presence := types.PresenceAvailable
if evt.Content.AsPresence().Presence != event.PresenceOnline {
presence = types.PresenceUnavailable
user.zlog.Debug().Msg("Marking offline")
user.log.Debugln("Marking offline")
} else {
user.zlog.Debug().Msg("Marking online")
user.log.Debugln("Marking online")
}
user.lastPresence = presence
if user.Client.Store.PushName != "" {
err := user.Client.SendPresence(presence)
if err != nil {
user.zlog.Err(err).Msg("Failed to set presence")
user.log.Warnln("Failed to set presence:", err)
}
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -23,7 +23,7 @@ import (
"sync"
"time"
"github.com/rs/zerolog"
log "maunium.net/go/maulogger/v2"
"go.mau.fi/whatsmeow"
@ -123,7 +123,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
}
}
func (portal *Portal) sendErrorMessage(ctx context.Context, evt *event.Event, err error, confirmed bool, editID id.EventID) id.EventID {
func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType string, confirmed bool, editID id.EventID) id.EventID {
if !portal.bridge.Config.Bridge.MessageErrorNotices {
return ""
}
@ -131,21 +131,6 @@ func (portal *Portal) sendErrorMessage(ctx context.Context, evt *event.Event, er
if confirmed {
certainty = "was not"
}
var msgType string
switch evt.Type {
case event.EventMessage:
msgType = "message"
case event.EventReaction:
msgType = "reaction"
case event.EventRedaction:
msgType = "redaction"
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default:
msgType = "unknown event"
}
msg := fmt.Sprintf("\u26a0 Your %s %s bridged: %v", msgType, certainty, err)
if errors.Is(err, errMessageTakingLong) {
msg = fmt.Sprintf("\u26a0 Bridging your %s is taking longer than usual", msgType)
@ -159,15 +144,15 @@ func (portal *Portal) sendErrorMessage(ctx context.Context, evt *event.Event, er
} else {
content.SetReply(evt)
}
resp, err := portal.sendMainIntentMessage(ctx, content)
resp, err := portal.sendMainIntentMessage(content)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to send bridging error message")
portal.log.Warnfln("Failed to send bridging error message:", err)
return ""
}
return resp.EventID
}
func (portal *Portal) sendStatusEvent(ctx context.Context, evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) {
func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) {
if !portal.bridge.Config.Bridge.MessageStatusEvents {
return
}
@ -194,56 +179,75 @@ func (portal *Portal) sendStatusEvent(ctx context.Context, evtID, lastRetry id.E
content.Reason, content.Status, _, _, content.Message = errorToStatusReason(err)
content.Error = err.Error()
}
_, err = intent.SendMessageEvent(ctx, portal.MXID, event.BeeperMessageStatus, &content)
_, err = intent.SendMessageEvent(portal.MXID, event.BeeperMessageStatus, &content)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to send message status event")
portal.log.Warnln("Failed to send message status event:", err)
}
}
func (portal *Portal) sendDeliveryReceipt(ctx context.Context, eventID id.EventID) {
func (portal *Portal) sendDeliveryReceipt(eventID id.EventID) {
if portal.bridge.Config.Bridge.DeliveryReceipts {
err := portal.bridge.Bot.SendReceipt(ctx, portal.MXID, eventID, event.ReceiptTypeRead, nil)
err := portal.bridge.Bot.SendReceipt(portal.MXID, eventID, event.ReceiptTypeRead, nil)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to mark message as read by bot (Matrix-side delivery receipt)")
portal.log.Debugfln("Failed to send delivery receipt for %s: %v", eventID, err)
}
}
}
func (portal *Portal) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, ms *metricSender) {
func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part string, ms *metricSender) {
var msgType string
switch evt.Type {
case event.EventMessage:
msgType = "message"
case event.EventReaction:
msgType = "reaction"
case event.EventRedaction:
msgType = "redaction"
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
msgType = "poll response"
case TypeMSC3381PollStart:
msgType = "poll start"
default:
msgType = "unknown event"
}
evtDescription := evt.ID.String()
if evt.Type == event.EventRedaction {
evtDescription += fmt.Sprintf(" of %s", evt.Redacts)
}
origEvtID := evt.ID
if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil {
origEvtID = retryMeta.OriginalEventID
}
if err != nil {
level := zerolog.ErrorLevel
level := log.LevelError
if part == "Ignoring" {
level = zerolog.DebugLevel
level = log.LevelDebug
}
zerolog.Ctx(ctx).WithLevel(level).Err(err).Msg(part + " Matrix event")
portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err)
reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err)
checkpointStatus := status.ReasonToCheckpointStatus(reason, statusCode)
portal.bridge.SendMessageCheckpoint(evt, status.MsgStepRemote, err, checkpointStatus, ms.getRetryNum())
if sendNotice {
ms.setNoticeID(portal.sendErrorMessage(ctx, evt, err, isCertain, ms.getNoticeID()))
ms.setNoticeID(portal.sendErrorMessage(evt, err, msgType, isCertain, ms.getNoticeID()))
}
portal.sendStatusEvent(ctx, origEvtID, evt.ID, err, nil)
portal.sendStatusEvent(origEvtID, evt.ID, err, nil)
} else {
zerolog.Ctx(ctx).Debug().Msg("Successfully handled Matrix event")
portal.sendDeliveryReceipt(ctx, evt.ID)
portal.log.Debugfln("Handled Matrix %s %s", msgType, evtDescription)
portal.sendDeliveryReceipt(evt.ID)
portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, ms.getRetryNum())
var deliveredTo *[]id.UserID
if portal.IsPrivateChat() {
deliveredTo = &[]id.UserID{}
}
portal.sendStatusEvent(ctx, origEvtID, evt.ID, nil, deliveredTo)
portal.sendStatusEvent(origEvtID, evt.ID, nil, deliveredTo)
if prevNotice := ms.popNoticeID(); prevNotice != "" {
_, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, prevNotice, mautrix.ReqRedact{
_, _ = portal.MainIntent().RedactEvent(portal.MXID, prevNotice, mautrix.ReqRedact{
Reason: "error resolved",
})
}
}
if ms != nil {
zerolog.Ctx(ctx).Debug().Object("timings", ms.timings).Msg("Matrix event timings")
portal.log.Debugfln("Timings for %s: %s", evt.ID, ms.timings.String())
}
}
@ -260,16 +264,47 @@ type messageTimings struct {
totalSend time.Duration
}
func (mt *messageTimings) MarshalZerologObject(e *zerolog.Event) {
e.Dur("init_receive", mt.initReceive).
Dur("decrypt", mt.decrypt).
Dur("implicit_rr", mt.implicitRR).
Dur("portal_queue", mt.portalQueue).
Dur("total_receive", mt.totalReceive).
Dur("preproc", mt.preproc).
Dur("convert", mt.convert).
Object("whatsmeow", mt.whatsmeow).
Dur("total_send", mt.totalSend)
func niceRound(dur time.Duration) time.Duration {
switch {
case dur < time.Millisecond:
return dur
case dur < time.Second:
return dur.Round(100 * time.Microsecond)
default:
return dur.Round(time.Millisecond)
}
}
func (mt *messageTimings) String() string {
mt.initReceive = niceRound(mt.initReceive)
mt.decrypt = niceRound(mt.decrypt)
mt.portalQueue = niceRound(mt.portalQueue)
mt.totalReceive = niceRound(mt.totalReceive)
mt.implicitRR = niceRound(mt.implicitRR)
mt.preproc = niceRound(mt.preproc)
mt.convert = niceRound(mt.convert)
mt.whatsmeow.Queue = niceRound(mt.whatsmeow.Queue)
mt.whatsmeow.Marshal = niceRound(mt.whatsmeow.Marshal)
mt.whatsmeow.GetParticipants = niceRound(mt.whatsmeow.GetParticipants)
mt.whatsmeow.GetDevices = niceRound(mt.whatsmeow.GetDevices)
mt.whatsmeow.GroupEncrypt = niceRound(mt.whatsmeow.GroupEncrypt)
mt.whatsmeow.PeerEncrypt = niceRound(mt.whatsmeow.PeerEncrypt)
mt.whatsmeow.Send = niceRound(mt.whatsmeow.Send)
mt.whatsmeow.Resp = niceRound(mt.whatsmeow.Resp)
mt.whatsmeow.Retry = niceRound(mt.whatsmeow.Retry)
mt.totalSend = niceRound(mt.totalSend)
whatsmeowTimings := "N/A"
if mt.totalSend > 0 {
format := "queue: %[1]s, marshal: %[2]s, ske: %[3]s, pcp: %[4]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
if mt.whatsmeow.GetParticipants == 0 && mt.whatsmeow.GroupEncrypt == 0 {
format = "queue: %[1]s, marshal: %[2]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
}
if mt.whatsmeow.Retry > 0 {
format += ", retry: %[9]s"
}
whatsmeowTimings = fmt.Sprintf(format, mt.whatsmeow.Queue, mt.whatsmeow.Marshal, mt.whatsmeow.GroupEncrypt, mt.whatsmeow.GetParticipants, mt.whatsmeow.GetDevices, mt.whatsmeow.PeerEncrypt, mt.whatsmeow.Send, mt.whatsmeow.Resp, mt.whatsmeow.Retry)
}
return fmt.Sprintf("BRIDGE: receive: %s, decrypt: %s, queue: %s, total hs->portal: %s, implicit rr: %s -- PORTAL: preprocess: %s, convert: %s, total send: %s -- WHATSMEOW: %s", mt.initReceive, mt.decrypt, mt.implicitRR, mt.portalQueue, mt.totalReceive, mt.preproc, mt.convert, mt.totalSend, whatsmeowTimings)
}
type metricSender struct {
@ -310,13 +345,13 @@ func (ms *metricSender) setNoticeID(evtID id.EventID) {
}
}
func (ms *metricSender) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, completed bool) {
func (ms *metricSender) sendMessageMetrics(evt *event.Event, err error, part string, completed bool) {
ms.lock.Lock()
defer ms.lock.Unlock()
if !completed && ms.completed {
return
}
ms.portal.sendMessageMetrics(ctx, evt, err, part, ms)
ms.portal.sendMessageMetrics(evt, err, part, ms)
ms.retryNum++
ms.completed = completed
}

View file

@ -18,7 +18,6 @@ package main
import (
"context"
"errors"
"net/http"
"runtime/debug"
"strconv"
@ -28,7 +27,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
log "maunium.net/go/maulogger/v2"
"go.mau.fi/whatsmeow/types"
@ -41,7 +40,7 @@ import (
type MetricsHandler struct {
db *database.Database
server *http.Server
log zerolog.Logger
log log.Logger
running bool
ctx context.Context
@ -71,7 +70,7 @@ type MetricsHandler struct {
loggedInStateLock sync.Mutex
}
func NewMetricsHandler(address string, log zerolog.Logger, db *database.Database) *MetricsHandler {
func NewMetricsHandler(address string, log log.Logger, db *database.Database) *MetricsHandler {
portalCount := promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "whatsapp_portals_total",
Help: "Number of portal rooms on Matrix",
@ -233,31 +232,31 @@ func (mh *MetricsHandler) TrackConnectionState(jid types.JID, connected bool) {
func (mh *MetricsHandler) updateStats() {
start := time.Now()
var puppetCount int
err := mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
err := mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
if err != nil {
mh.log.Err(err).Msg("Failed to scan number of puppets")
mh.log.Warnln("Failed to scan number of puppets:", err)
} else {
mh.puppetCount.Set(float64(puppetCount))
}
var userCount int
err = mh.db.QueryRow(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
if err != nil {
mh.log.Err(err).Msg("Failed to scan number of users")
mh.log.Warnln("Failed to scan number of users:", err)
} else {
mh.userCount.Set(float64(userCount))
}
var messageCount int
err = mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
err = mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
if err != nil {
mh.log.Err(err).Msg("Failed to scan number of messages")
mh.log.Warnln("Failed to scan number of messages:", err)
} else {
mh.messageCount.Set(float64(messageCount))
}
var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int
err = mh.db.QueryRow(mh.ctx, `
err = mh.db.QueryRowContext(mh.ctx, `
SELECT
COUNT(CASE WHEN jid LIKE '%@g.us' AND encrypted THEN 1 END) AS encrypted_group_portals,
COUNT(CASE WHEN jid LIKE '%@s.whatsapp.net' AND encrypted THEN 1 END) AS encrypted_private_portals,
@ -266,7 +265,7 @@ func (mh *MetricsHandler) updateStats() {
FROM portal WHERE mxid<>''
`).Scan(&encryptedGroupCount, &encryptedPrivateCount, &unencryptedGroupCount, &unencryptedPrivateCount)
if err != nil {
mh.log.Err(err).Msg("Failed to scan number of portals")
mh.log.Warnln("Failed to scan number of portals:", err)
} else {
mh.encryptedGroupCount.Set(float64(encryptedGroupCount))
mh.encryptedPrivateCount.Set(float64(encryptedPrivateCount))
@ -280,10 +279,7 @@ func (mh *MetricsHandler) startUpdatingStats() {
defer func() {
err := recover()
if err != nil {
mh.log.WithLevel(zerolog.PanicLevel).
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
Interface(zerolog.ErrorFieldName, err).
Msg("Panic in metric updater")
mh.log.Fatalfln("Panic in metric updater: %v\n%s", err, string(debug.Stack()))
}
}()
ticker := time.Tick(10 * time.Second)
@ -303,8 +299,8 @@ func (mh *MetricsHandler) Start() {
go mh.startUpdatingStats()
err := mh.server.ListenAndServe()
mh.running = false
if err != nil && !errors.Is(err, http.ErrServerClosed) {
mh.log.Err(err).Msg("Error in metrics listener")
if err != nil && err != http.ErrServerClosed {
mh.log.Fatalln("Error in metrics listener:", err)
}
}
@ -315,6 +311,6 @@ func (mh *MetricsHandler) Stop() {
mh.stopRecorder()
err := mh.server.Close()
if err != nil {
mh.log.Err(err).Msg("Failed to close metrics listener")
mh.log.Errorln("Error closing metrics listener:", err)
}
}

1991
portal.go

File diff suppressed because it is too large Load diff

View file

@ -17,40 +17,41 @@
package main
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"strings"
"time"
"github.com/beeper/libserv/pkg/requestlog"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"go.mau.fi/whatsmeow/appstate"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/id"
)
type ProvisioningAPI struct {
bridge *WABridge
log zerolog.Logger
log log.Logger
}
func (prov *ProvisioningAPI) Init() {
prov.log.Debug().Str("base_path", prov.bridge.Config.Bridge.Provisioning.Prefix).Msg("Enabling provisioning API")
prov.log = prov.bridge.Log.Sub("Provisioning")
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
r.Use(hlog.NewHandler(prov.log))
r.Use(requestlog.AccessLogger(true))
r.Use(prov.AuthMiddleware)
r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
@ -72,7 +73,7 @@ func (prov *ProvisioningAPI) Init() {
prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.bridge_state", prov.BridgeStatePing).Methods(http.MethodPost)
if prov.bridge.Config.Bridge.Provisioning.DebugEndpoints {
prov.log.Debug().Msg("Enabling debug API at /debug")
prov.log.Debugln("Enabling debug API at /debug")
r := prov.bridge.AS.Router.PathPrefix("/debug").Subrouter()
r.Use(prov.AuthMiddleware)
r.PathPrefix("/pprof").Handler(http.DefaultServeMux)
@ -82,6 +83,26 @@ func (prov *ProvisioningAPI) Init() {
r.HandleFunc("/v1/delete_connection", prov.Disconnect).Methods(http.MethodPost)
}
type responseWrap struct {
http.ResponseWriter
statusCode int
}
var _ http.Hijacker = (*responseWrap)(nil)
func (rw *responseWrap) WriteHeader(statusCode int) {
rw.ResponseWriter.WriteHeader(statusCode)
rw.statusCode = statusCode
}
func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("response does not implement http.Hijacker")
}
return hijacker.Hijack()
}
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
@ -98,7 +119,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
auth = auth[len("Bearer "):]
}
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
hlog.FromRequest(r).Debug().Msg("Authentication token does not match shared secret")
prov.log.Infof("Authentication token does not match shared secret")
jsonResponse(w, http.StatusForbidden, map[string]interface{}{
"error": "Authentication token does not match shared secret",
"errcode": "M_FORBIDDEN",
@ -107,12 +128,11 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
}
userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID))
if user != nil {
hlog.FromRequest(r).UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Stringer("user_id", user.MXID)
})
}
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user)))
start := time.Now()
wWrap := &responseWrap{w, 200}
h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user)))
duration := time.Now().Sub(start).Seconds()
prov.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode)
})
}
@ -137,7 +157,7 @@ func (prov *ProvisioningAPI) DeleteSession(w http.ResponseWriter, r *http.Reques
return
}
user.DeleteConnection()
user.DeleteSession(r.Context())
user.DeleteSession()
jsonResponse(w, http.StatusOK, Response{true, "Session information purged"})
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
}
@ -225,7 +245,7 @@ func (prov *ProvisioningAPI) ListContacts(w http.ResponseWriter, r *http.Request
ErrCode: "no session",
})
} else if contacts, err := user.Session.Contacts.GetAllContacts(); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to fetch all contacts")
prov.log.Errorfln("Failed to fetch %s's contacts: %v", user.MXID, err)
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while fetching contact list",
ErrCode: "failed to get contacts",
@ -262,7 +282,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
if r.Method == http.MethodPost {
err := user.ResyncGroups(r.URL.Query().Get("create_portals") == "true")
if err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to resync groups")
prov.log.Errorfln("Failed to resync %s's groups: %v", user.MXID, err)
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while resyncing groups",
ErrCode: "failed to sync groups",
@ -271,7 +291,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
}
}
if groups, err := user.getCachedGroupList(); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to fetch group list")
prov.log.Errorfln("Failed to fetch %s's groups: %v", user.MXID, err)
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Internal server error while fetching group list",
ErrCode: "failed to get groups",
@ -348,17 +368,17 @@ func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) {
// resolveIdentifier already responded with an error
return
}
portal, puppet, justCreated, err := user.StartPM(r.Context(), jid, "provisioning API PM")
portal, puppet, justCreated, err := user.StartPM(jid, "provisioning API PM")
if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err),
})
}
statusCode := http.StatusOK
status := http.StatusOK
if justCreated {
statusCode = http.StatusCreated
status = http.StatusCreated
}
jsonResponse(w, statusCode, PortalInfo{
jsonResponse(w, status, PortalInfo{
RoomID: portal.MXID,
OtherUser: &OtherUserInfo{
JID: puppet.JID,
@ -429,30 +449,29 @@ func (prov *ProvisioningAPI) OpenGroup(w http.ResponseWriter, r *http.Request) {
ErrCode: "invalid group id",
})
} else if info, err := user.Client.GetGroupInfo(jid); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to get group info by JID")
// TODO return better responses for different errors (like ErrGroupNotFound and ErrNotInGroup)
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to get group info: %v", err),
ErrCode: "error getting group info",
})
} else {
hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Importing group chat for user")
prov.log.Debugln("Importing", jid, "for", user.MXID)
portal := user.GetPortalByJID(info.JID)
statusCode := http.StatusOK
status := http.StatusOK
if len(portal.MXID) == 0 {
err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
err = portal.CreateMatrixRoom(user, info, nil, true, true)
if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err),
})
return
}
statusCode = http.StatusCreated
status = http.StatusCreated
}
jsonResponse(w, statusCode, PortalInfo{
jsonResponse(w, status, PortalInfo{
RoomID: portal.MXID,
GroupInfo: info,
JustCreated: statusCode == http.StatusCreated,
JustCreated: status == http.StatusCreated,
})
}
}
@ -476,7 +495,6 @@ func (prov *ProvisioningAPI) resolveGroupInvite(w http.ResponseWriter, r *http.R
ErrCode: "invalid invite link",
})
} else {
hlog.FromRequest(r).Err(err).Msg("Failed to get group info from link")
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to fetch group info with link: %v", err),
ErrCode: "error getting group info",
@ -512,30 +530,29 @@ func (prov *ProvisioningAPI) JoinGroup(w http.ResponseWriter, r *http.Request) {
}()
inviteCode, _ := mux.Vars(r)["inviteCode"]
if jid, err := user.Client.JoinGroupWithLink(inviteCode); err != nil {
hlog.FromRequest(r).Err(err).Msg("Failed to join group")
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to join group: %v", err),
ErrCode: "error joining group",
})
} else {
hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Successfully joined group")
prov.log.Debugln(user.MXID, "successfully joined group", jid)
portal := user.GetPortalByJID(jid)
statusCode := http.StatusOK
status := http.StatusOK
if len(portal.MXID) == 0 {
time.Sleep(500 * time.Millisecond) // Wait for incoming group info to create the portal automatically
err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
err = portal.CreateMatrixRoom(user, info, nil, true, true)
if err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Failed to create portal: %v", err),
})
return
}
statusCode = http.StatusCreated
status = http.StatusCreated
}
jsonResponse(w, statusCode, PortalInfo{
jsonResponse(w, status, PortalInfo{
RoomID: portal.MXID,
GroupInfo: info,
JustCreated: statusCode == http.StatusCreated,
JustCreated: status == http.StatusCreated,
})
}
}
@ -599,7 +616,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
} else {
err := user.Client.Logout()
if err != nil {
hlog.FromRequest(r).Err(err).Msg("Unknown error while logging out")
user.log.Warnln("Error while logging out:", err)
if !force {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: fmt.Sprintf("Unknown error while logging out: %v", err),
@ -615,7 +632,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
user.bridge.Metrics.TrackConnectionState(user.JID, false)
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
user.DeleteSession(r.Context())
user.DeleteSession()
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
}
@ -629,17 +646,16 @@ var upgrader = websocket.Upgrader{
func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID))
log := hlog.FromRequest(r)
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Err(err).Msg("Failed to upgrade connection to websocket")
prov.log.Errorln("Failed to upgrade connection to websocket:", err)
return
}
defer func() {
err := c.Close()
if err != nil {
log.Debug().Err(err).Msg("Error closing websocket")
user.log.Debugln("Error closing websocket:", err)
}
}()
@ -654,26 +670,23 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
}()
ctx, cancel := context.WithCancel(context.Background())
c.SetCloseHandler(func(code int, text string) error {
log.Debug().Int("close_code", code).Msg("Login websocket closed, cancelling login")
user.log.Debugfln("Login websocket closed (%d), cancelling login", code)
cancel()
return nil
})
if userTimezone := r.URL.Query().Get("tz"); userTimezone != "" {
log.Debug().Str("timezone", userTimezone).Msg("Updating user timezone")
user.log.Debug("Setting timezone to %s", userTimezone)
user.Timezone = userTimezone
err = user.Update(r.Context())
if err != nil {
log.Err(err).Msg("Failed to save user after updating timezone")
}
user.Update()
} else {
log.Debug().Msg("No timezone provided in request")
user.log.Debug("No timezone provided in request")
}
qrChan, err := user.Login(ctx)
expiryTime := time.Now().Add(160 * time.Second)
if err != nil {
log.Err(err).Msg("Failed to log in via provisioning API")
user.log.Errorln("Failed to log in from provisioning API:", err)
if errors.Is(err, ErrAlreadyLoggedIn) {
go user.Connect()
_ = c.WriteJSON(Error{
@ -691,7 +704,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
if phoneNum != "" {
pairingCode, err := user.Client.PairPhone(phoneNum, true, whatsmeow.PairClientChrome, "Chrome (Linux)")
if err != nil {
log.Err(err).Msg("Failed to start phone code login")
user.zlog.Err(err).Msg("Failed to start phone code login")
_ = c.WriteJSON(Error{
Error: "Failed to request pairing code",
ErrCode: "code error",
@ -699,7 +712,6 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
go user.DeleteConnection()
return
} else {
log.Debug().Msg("Started phone number login")
_ = c.WriteJSON(map[string]any{
"pairing_code": pairingCode,
"timeout": int(time.Until(expiryTime).Seconds()),
@ -707,7 +719,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
}
}
log.Debug().Msg("Started login via provisioning API")
user.log.Debugln("Started login via provisioning API")
Analytics.Track(user.MXID, "$login_start")
for {
@ -716,7 +728,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
switch evt.Event {
case whatsmeow.QRChannelSuccess.Event:
jid := user.Client.Store.ID
log.Debug().Stringer("jid", jid).Msg("Successful login via provisioning API")
user.log.Debugln("Successful login as", jid, "via provisioning API")
Analytics.Track(user.MXID, "$login_success")
_ = c.WriteJSON(map[string]interface{}{
"success": true,
@ -725,7 +737,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
"platform": user.Client.Store.Platform,
})
case whatsmeow.QRChannelTimeout.Event:
log.Debug().Msg("Login via provisioning API timed out")
user.log.Debugln("Login via provisioning API timed out")
errCode := "login timed out"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{
@ -733,7 +745,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
ErrCode: errCode,
})
case whatsmeow.QRChannelErrUnexpectedEvent.Event:
log.Debug().Msg("Login via provisioning API failed due to unexpected event")
user.log.Debugln("Login via provisioning API failed due to unexpected event")
errCode := "unexpected event"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{
@ -741,7 +753,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
ErrCode: errCode,
})
case whatsmeow.QRChannelClientOutdated.Event:
log.Debug().Msg("Login via provisioning API failed due to outdated client")
user.log.Debugln("Login via provisioning API failed due to outdated client")
errCode := "bridge outdated"
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{

137
puppet.go
View file

@ -17,15 +17,15 @@
package main
import (
"context"
"fmt"
"regexp"
"sync"
"time"
"github.com/rs/zerolog"
"go.mau.fi/whatsmeow/types"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
@ -59,7 +59,6 @@ func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
}
func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
ctx := context.TODO()
jid = jid.ToNonAD()
if jid.Server == types.LegacyUserServer {
jid.Server = types.DefaultUserServer
@ -70,19 +69,11 @@ func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
defer br.puppetsLock.Unlock()
puppet, ok := br.puppets[jid]
if !ok {
dbPuppet, err := br.DB.Puppet.Get(ctx, jid)
if err != nil {
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to get puppet from database")
return nil
}
dbPuppet := br.DB.Puppet.Get(jid)
if dbPuppet == nil {
dbPuppet = br.DB.Puppet.New()
dbPuppet.JID = jid
err = dbPuppet.Insert(ctx)
if err != nil {
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to insert new puppet to database")
return nil
}
dbPuppet.Insert()
}
puppet = br.NewPuppet(dbPuppet)
br.puppets[puppet.JID] = puppet
@ -98,10 +89,7 @@ func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
defer br.puppetsLock.Unlock()
puppet, ok := br.puppetsByCustomMXID[mxid]
if !ok {
dbPuppet, err := br.DB.Puppet.GetByCustomMXID(context.TODO(), mxid)
if err != nil {
br.ZLog.Err(err).Stringer("mxid", mxid).Msg("Failed to get puppet by custom mxid from database")
}
dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid)
if dbPuppet == nil {
return nil
}
@ -149,18 +137,14 @@ func (puppet *Puppet) GetMXID() id.UserID {
}
func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID(context.TODO()))
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID())
}
func (br *WABridge) GetAllPuppets() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll(context.TODO()))
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll())
}
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet, err error) []*Puppet {
if err != nil {
br.ZLog.Err(err).Msg("Error getting puppets from database")
return nil
}
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
br.puppetsLock.Lock()
defer br.puppetsLock.Unlock()
output := make([]*Puppet, len(dbPuppets))
@ -191,7 +175,7 @@ func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
return &Puppet{
Puppet: dbPuppet,
bridge: br,
zlog: br.ZLog.With().Stringer("puppet_jid", dbPuppet.JID).Logger(),
log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
MXID: br.FormatPuppetMXID(dbPuppet.JID),
}
@ -201,7 +185,7 @@ type Puppet struct {
*database.Puppet
bridge *WABridge
zlog zerolog.Logger
log log.Logger
typingIn id.RoomID
typingAt time.Time
@ -239,47 +223,47 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI {
return puppet.bridge.AS.Intent(puppet.MXID)
}
func (puppet *Puppet) UpdateAvatar(ctx context.Context, source *User, forcePortalSync bool) bool {
changed := source.updateAvatar(ctx, puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.DefaultIntent())
func (puppet *Puppet) UpdateAvatar(source *User, forcePortalSync bool) bool {
changed := source.updateAvatar(puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.log, puppet.DefaultIntent())
if !changed || puppet.Avatar == "unauthorized" {
if forcePortalSync {
go puppet.updatePortalAvatar(ctx)
go puppet.updatePortalAvatar()
}
return changed
}
err := puppet.DefaultIntent().SetAvatarURL(ctx, puppet.AvatarURL)
err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to set avatar from puppet")
puppet.log.Warnln("Failed to set avatar:", err)
} else {
puppet.AvatarSet = true
}
go puppet.updatePortalAvatar(ctx)
go puppet.updatePortalAvatar()
return true
}
func (puppet *Puppet) UpdateName(ctx context.Context, contact types.ContactInfo, forcePortalSync bool) bool {
func (puppet *Puppet) UpdateName(contact types.ContactInfo, forcePortalSync bool) bool {
newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(puppet.JID, contact)
if (puppet.Displayname != newName || !puppet.NameSet) && quality >= puppet.NameQuality {
oldName := puppet.Displayname
puppet.Displayname = newName
puppet.NameQuality = quality
puppet.NameSet = false
err := puppet.DefaultIntent().SetDisplayName(ctx, newName)
err := puppet.DefaultIntent().SetDisplayName(newName)
if err == nil {
puppet.zlog.Debug().Str("old_name", oldName).Str("new_name", newName).Msg("Updated name")
puppet.log.Debugln("Updated name", oldName, "->", newName)
puppet.NameSet = true
go puppet.updatePortalName(ctx)
go puppet.updatePortalName()
} else {
puppet.zlog.Err(err).Msg("Failed to set displayname")
puppet.log.Warnln("Failed to set display name:", err)
}
return true
} else if forcePortalSync {
go puppet.updatePortalName(ctx)
go puppet.updatePortalName()
}
return false
}
func (puppet *Puppet) UpdateContactInfo(ctx context.Context) bool {
func (puppet *Puppet) UpdateContactInfo() bool {
if !puppet.bridge.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
return false
}
@ -297,9 +281,9 @@ func (puppet *Puppet) UpdateContactInfo(ctx context.Context) bool {
"com.beeper.bridge.service": "whatsapp",
"com.beeper.bridge.network": "whatsapp",
}
err := puppet.DefaultIntent().BeeperUpdateProfile(ctx, contactInfo)
err := puppet.DefaultIntent().BeeperUpdateProfile(contactInfo)
if err != nil {
puppet.zlog.Err(err).Msg("Failed to store custom contact info in profile")
puppet.log.Warnln("Failed to store custom contact info in profile:", err)
return false
} else {
puppet.ContactInfoSet = true
@ -316,7 +300,7 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) {
}
}
func (puppet *Puppet) updatePortalAvatar(ctx context.Context) {
func (puppet *Puppet) updatePortalAvatar() {
puppet.updatePortalMeta(func(portal *Portal) {
if portal.Avatar == puppet.Avatar && portal.AvatarURL == puppet.AvatarURL && (portal.AvatarSet || !portal.shouldSetDMRoomMetadata()) {
return
@ -324,31 +308,28 @@ func (puppet *Puppet) updatePortalAvatar(ctx context.Context) {
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
portal.AvatarSet = false
defer portal.Update(nil)
if len(portal.MXID) > 0 && !portal.shouldSetDMRoomMetadata() {
portal.UpdateBridgeInfo(ctx)
portal.UpdateBridgeInfo()
} else if len(portal.MXID) > 0 {
_, err := portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, puppet.AvatarURL)
_, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL)
if err != nil {
portal.zlog.Err(err).Msg("Failed to set avatar from puppet")
portal.log.Warnln("Failed to set avatar:", err)
} else {
portal.AvatarSet = true
portal.UpdateBridgeInfo(ctx)
portal.UpdateBridgeInfo()
}
}
err := portal.Update(ctx)
if err != nil {
portal.zlog.Err(err).Msg("Failed to save portal after updating avatar from puppet")
}
})
}
func (puppet *Puppet) updatePortalName(ctx context.Context) {
func (puppet *Puppet) updatePortalName() {
puppet.updatePortalMeta(func(portal *Portal) {
portal.UpdateName(ctx, puppet.Displayname, types.EmptyJID, true)
portal.UpdateName(puppet.Displayname, types.EmptyJID, true)
})
}
func (puppet *Puppet) SyncContact(ctx context.Context, source *User, onlyIfNoName, shouldHavePushName bool, reason string) {
func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName bool, reason string) {
if puppet == nil {
return
}
@ -356,67 +337,39 @@ func (puppet *Puppet) SyncContact(ctx context.Context, source *User, onlyIfNoNam
source.EnqueuePuppetResync(puppet)
return
}
log := zerolog.Ctx(ctx).With().
Str("method", "Puppet.SyncContact").
Stringer("puppet_jid", puppet.JID).
Stringer("source_user_jid", source.JID).
Stringer("source_user_mxid", source.MXID).
Logger()
ctx = log.WithContext(ctx)
contact, err := source.Client.Store.Contacts.GetContact(puppet.JID)
if err != nil {
log.Err(err).
Stringer("source_mxid", source.MXID).
Str("sync_reason", reason).
Msg("Failed to get contact info through user in SyncContact")
puppet.log.Warnfln("Failed to get contact info through %s in SyncContact: %v (sync reason: %s)", source.MXID, reason)
} else if !contact.Found {
log.Warn().
Stringer("source_mxid", source.MXID).
Str("sync_reason", reason).
Msg("No contact info found through user in SyncContact")
puppet.log.Warnfln("No contact info found through %s in SyncContact (sync reason: %s)", source.MXID, reason)
}
puppet.syncInternal(ctx, source, &contact, false, false)
puppet.Sync(source, &contact, false, false)
}
func (puppet *Puppet) Sync(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
log := zerolog.Ctx(ctx).With().
Str("method", "Puppet.Sync").
Stringer("puppet_jid", puppet.JID).
Stringer("source_user_jid", source.JID).
Stringer("source_user_mxid", source.MXID).
Logger()
ctx = log.WithContext(ctx)
puppet.syncInternal(ctx, source, contact, forceAvatarSync, forcePortalSync)
}
func (puppet *Puppet) syncInternal(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
log := zerolog.Ctx(ctx)
func (puppet *Puppet) Sync(source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
puppet.syncLock.Lock()
defer puppet.syncLock.Unlock()
err := puppet.DefaultIntent().EnsureRegistered(ctx)
err := puppet.DefaultIntent().EnsureRegistered()
if err != nil {
log.Err(err).Msg("Failed to ensure registered")
puppet.log.Errorln("Failed to ensure registered:", err)
}
log.Debug().Stringer("source_jid", source.JID).Msg("Syncing info through user")
puppet.log.Debugfln("Syncing info through %s", source.JID)
update := false
if contact != nil {
if puppet.JID.User == source.JID.User {
contact.PushName = source.Client.Store.PushName
}
update = puppet.UpdateName(ctx, *contact, forcePortalSync) || update
update = puppet.UpdateName(*contact, forcePortalSync) || update
}
if len(puppet.Avatar) == 0 || forceAvatarSync || puppet.bridge.Config.Bridge.UserAvatarSync {
update = puppet.UpdateAvatar(ctx, source, forcePortalSync) || update
update = puppet.UpdateAvatar(source, forcePortalSync) || update
}
update = puppet.UpdateContactInfo(ctx) || update
update = puppet.UpdateContactInfo() || update
if update || puppet.LastSync.Add(24*time.Hour).Before(time.Now()) {
puppet.LastSync = time.Now()
err = puppet.Update(ctx)
if err != nil {
log.Err(err).Msg("Failed to save puppet after sync")
}
puppet.Update()
}
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
@ -19,6 +19,7 @@ package main
import (
"bytes"
"context"
"encoding/json"
"image"
"net/http"
"net/url"
@ -26,26 +27,33 @@ import (
"strings"
"time"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"golang.org/x/net/idna"
"google.golang.org/protobuf/proto"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
)
func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*event.BeeperLinkPreview {
type BeeperLinkPreview struct {
mautrix.RespPreviewURL
MatchedURL string `json:"matched_url"`
ImageEncryption *event.EncryptedFileInfo `json:"beeper:image:encryption,omitempty"`
}
func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*BeeperLinkPreview {
if msg.GetMatchedText() == "" {
return []*event.BeeperLinkPreview{}
return []*BeeperLinkPreview{}
}
output := &event.BeeperLinkPreview{
output := &BeeperLinkPreview{
MatchedURL: msg.GetMatchedText(),
LinkPreview: event.LinkPreview{
RespPreviewURL: mautrix.RespPreviewURL{
CanonicalURL: msg.GetCanonicalUrl(),
Title: msg.GetTitle(),
Description: msg.GetDescription(),
@ -60,7 +68,7 @@ func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *app
var err error
thumbnailData, err = source.Client.DownloadThumbnail(msg)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to download thumbnail for link preview")
portal.log.Warnfln("Failed to download thumbnail for link preview: %v", err)
}
}
if thumbnailData == nil && msg.JpegThumbnail != nil {
@ -85,9 +93,9 @@ func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *app
uploadMime = "application/octet-stream"
output.ImageEncryption = &event.EncryptedFileInfo{EncryptedFile: *crypto}
}
resp, err := intent.UploadBytes(ctx, uploadData, uploadMime)
resp, err := intent.UploadBytes(uploadData, uploadMime)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload thumbnail for link preview")
portal.log.Warnfln("Failed to reupload thumbnail for link preview: %v", err)
} else {
if output.ImageEncryption != nil {
output.ImageEncryption.URL = resp.ContentURI.CUString()
@ -100,37 +108,36 @@ func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *app
output.Type = "video.other"
}
return []*event.BeeperLinkPreview{output}
return []*BeeperLinkPreview{output}
}
var URLRegex = regexp.MustCompile(`https?://[^\s/_*]+(?:/\S*)?`)
func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, content *event.MessageEventContent, dest *waProto.ExtendedTextMessage) bool {
log := zerolog.Ctx(ctx)
var preview *event.BeeperLinkPreview
func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, evt *event.Event, dest *waProto.ExtendedTextMessage) bool {
var preview *BeeperLinkPreview
if content.BeeperLinkPreviews != nil {
// Note: this check explicitly happens after checking for nil: empty arrays are treated as no previews,
// but omitting the field means the bridge may look for URLs in the message text.
if len(content.BeeperLinkPreviews) == 0 {
rawPreview := gjson.GetBytes(evt.Content.VeryRaw, `com\.beeper\.linkpreviews`)
if rawPreview.Exists() && rawPreview.IsArray() {
var previews []BeeperLinkPreview
if err := json.Unmarshal([]byte(rawPreview.Raw), &previews); err != nil || len(previews) == 0 {
return false
}
// WhatsApp only supports a single preview.
preview = content.BeeperLinkPreviews[0]
preview = &previews[0]
} else if portal.bridge.Config.Bridge.URLPreviews {
if matchedURL := URLRegex.FindString(content.Body); len(matchedURL) == 0 {
if matchedURL := URLRegex.FindString(evt.Content.AsMessage().Body); len(matchedURL) == 0 {
return false
} else if parsed, err := url.Parse(matchedURL); err != nil {
return false
} else if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil {
return false
} else if mxPreview, err := portal.MainIntent().GetURLPreview(ctx, parsed.String()); err != nil {
log.Err(err).Str("url", matchedURL).Msg("Failed to fetch URL preview")
} else if mxPreview, err := portal.MainIntent().GetURLPreview(parsed.String()); err != nil {
portal.log.Warnfln("Failed to fetch preview for %s: %v", matchedURL, err)
return false
} else {
preview = &event.BeeperLinkPreview{
LinkPreview: *mxPreview,
MatchedURL: matchedURL,
preview = &BeeperLinkPreview{
RespPreviewURL: *mxPreview,
MatchedURL: matchedURL,
}
}
}
@ -156,22 +163,22 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
imageMXC = preview.ImageEncryption.URL.ParseOrIgnore()
}
if !imageMXC.IsEmpty() {
data, err := portal.MainIntent().DownloadBytes(ctx, imageMXC)
data, err := portal.MainIntent().DownloadBytesContext(ctx, imageMXC)
if err != nil {
log.Err(err).Str("image_url", string(preview.ImageURL)).Msg("Failed to download URL preview image")
portal.log.Errorfln("Failed to download URL preview image %s in %s: %v", preview.ImageURL, evt.ID, err)
return true
}
if preview.ImageEncryption != nil {
err = preview.ImageEncryption.DecryptInPlace(data)
if err != nil {
log.Err(err).Msg("Failed to decrypt URL preview image")
portal.log.Errorfln("Failed to decrypt URL preview image in %s: %v", evt.ID, err)
return true
}
}
dest.MediaKeyTimestamp = proto.Int64(time.Now().Unix())
uploadResp, err := sender.Client.Upload(ctx, data, whatsmeow.MediaLinkThumbnail)
if err != nil {
log.Err(err).Msg("Failed to reupload URL preview thumbnail")
portal.log.Errorfln("Failed to upload URL preview thumbnail in %s: %v", evt.ID, err)
return true
}
dest.ThumbnailSha256 = uploadResp.FileSHA256
@ -181,7 +188,7 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
var width, height int
dest.JpegThumbnail, width, height, err = createThumbnailAndGetSize(data, false)
if err != nil {
log.Err(err).Msg("Failed to create JPEG thumbnail for URL preview")
portal.log.Warnfln("Failed to create JPEG thumbnail for URL preview in %s: %v", evt.ID, err)
}
if preview.ImageHeight > 0 && preview.ImageWidth > 0 {
dest.ThumbnailWidth = proto.Uint32(uint32(preview.ImageWidth))

526
user.go

File diff suppressed because it is too large Load diff