mirror of
https://github.com/tulir/mautrix-whatsapp
synced 2024-06-14 08:58:22 +02:00
Update dependencies and lots of code
* Bump minimum Go version to 1.21 * Add contexts everywhere * Switch database code to new dbutil patterns * Finish switching away from maulogger
This commit is contained in:
parent
f8a22aab06
commit
103bfc31c6
3
.github/workflows/go.yml
vendored
3
.github/workflows/go.yml
vendored
|
@ -8,7 +8,8 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ["1.20", "1.21"]
|
||||
go-version: ["1.21", "1.22"]
|
||||
name: Lint ${{ matrix.go-version == '1.22' && '(latest)' || '(old)' }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
|
|
@ -15,6 +15,6 @@ repos:
|
|||
- id: go-vet-repo-mod
|
||||
|
||||
- repo: https://github.com/beeper/pre-commit-go
|
||||
rev: v0.2.2
|
||||
rev: v0.3.1
|
||||
hooks:
|
||||
- id: zerolog-ban-msgf
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
# v0.10.6 (unreleased)
|
||||
|
||||
* Bumped minimum Go version to 1.21.
|
||||
* Added 8-letter code pairing support to provisioning API.
|
||||
* Added more bugs to fix later.
|
||||
|
||||
# v0.10.5 (2023-12-16)
|
||||
|
||||
* Added support for sending media to channels.
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
|
@ -30,7 +30,7 @@ type AnalyticsClient struct {
|
|||
url string
|
||||
key string
|
||||
userID string
|
||||
log log.Logger
|
||||
log zerolog.Logger
|
||||
client http.Client
|
||||
}
|
||||
|
||||
|
@ -88,9 +88,9 @@ func (sc *AnalyticsClient) Track(userID id.UserID, event string, properties ...m
|
|||
props["bridge"] = "whatsapp"
|
||||
err := sc.trackSync(userID, event, props)
|
||||
if err != nil {
|
||||
sc.log.Errorfln("Error tracking %s: %v", event, err)
|
||||
sc.log.Err(err).Str("event", event).Msg("Error tracking event")
|
||||
} else {
|
||||
sc.log.Debugln("Tracked", event)
|
||||
sc.log.Debug().Str("event", event).Msg("Tracked event")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2021 Tulir Asokan, Sumner Evans
|
||||
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,22 +17,21 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix-whatsapp/database"
|
||||
)
|
||||
|
||||
type BackfillQueue struct {
|
||||
BackfillQuery *database.BackfillQuery
|
||||
BackfillQuery *database.BackfillTaskQuery
|
||||
reCheckChannels []chan bool
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
func (bq *BackfillQueue) ReCheck() {
|
||||
bq.log.Infofln("Sending re-checks to %d channels", len(bq.reCheckChannels))
|
||||
for _, channel := range bq.reCheckChannels {
|
||||
go func(c chan bool) {
|
||||
c <- true
|
||||
|
@ -40,12 +39,19 @@ func (bq *BackfillQueue) ReCheck() {
|
|||
}
|
||||
}
|
||||
|
||||
func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.Backfill {
|
||||
func (bq *BackfillQueue) GetNextBackfill(ctx context.Context, userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.BackfillTask {
|
||||
for {
|
||||
if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(userID, waitForBackfillTypes) {
|
||||
if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(ctx, userID, waitForBackfillTypes) {
|
||||
// check for immediate when dealing with deferred
|
||||
if backfill := bq.BackfillQuery.GetNext(userID, backfillTypes); backfill != nil {
|
||||
backfill.MarkDispatched()
|
||||
if backfill, err := bq.BackfillQuery.GetNext(ctx, userID, backfillTypes); err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get next backfill task")
|
||||
} else if backfill != nil {
|
||||
err = backfill.MarkDispatched(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Int("queue_id", backfill.QueueID).
|
||||
Msg("Failed to mark backfill task as dispatched")
|
||||
}
|
||||
return backfill
|
||||
}
|
||||
}
|
||||
|
@ -58,38 +64,73 @@ func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []datab
|
|||
}
|
||||
|
||||
func (user *User) HandleBackfillRequestsLoop(backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType) {
|
||||
log := user.zlog.With().
|
||||
Str("action", "backfill request loop").
|
||||
Any("types", backfillTypes).
|
||||
Logger()
|
||||
ctx := log.WithContext(context.TODO())
|
||||
reCheckChannel := make(chan bool)
|
||||
user.BackfillQueue.reCheckChannels = append(user.BackfillQueue.reCheckChannels, reCheckChannel)
|
||||
|
||||
for {
|
||||
req := user.BackfillQueue.GetNextBackfill(user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel)
|
||||
user.log.Infofln("Handling backfill request %s", req)
|
||||
req := user.BackfillQueue.GetNextBackfill(ctx, user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel)
|
||||
log.Info().Any("backfill_request", req).Msg("Handling backfill request")
|
||||
log := log.With().
|
||||
Int("queue_id", req.QueueID).
|
||||
Stringer("portal_jid", req.Portal.JID).
|
||||
Logger()
|
||||
ctx := log.WithContext(ctx)
|
||||
|
||||
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, *req.Portal)
|
||||
if conv == nil {
|
||||
user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String())
|
||||
req.MarkDone()
|
||||
conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, req.Portal)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get conversation data for backfill request")
|
||||
continue
|
||||
} else if conv == nil {
|
||||
log.Debug().Msg("Couldn't find conversation data for backfill request")
|
||||
err = req.MarkDone(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to mark backfill request as done after data was not found")
|
||||
}
|
||||
continue
|
||||
}
|
||||
portal := user.GetPortalByJID(conv.PortalKey.JID)
|
||||
|
||||
// Update the client store with basic chat settings.
|
||||
if conv.MuteEndTime.After(time.Now()) {
|
||||
user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
|
||||
err = user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save muted until time from conversation data")
|
||||
}
|
||||
}
|
||||
if conv.Archived {
|
||||
user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
|
||||
err = user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save archived state from conversation data")
|
||||
}
|
||||
}
|
||||
if conv.Pinned > 0 {
|
||||
user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
|
||||
err = user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save pinned state from conversation data")
|
||||
}
|
||||
}
|
||||
|
||||
if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration {
|
||||
log.Debug().
|
||||
Uint32("old_time", portal.ExpirationTime).
|
||||
Uint32("new_time", *conv.EphemeralExpiration).
|
||||
Msg("Updating portal ephemeral expiration time")
|
||||
portal.ExpirationTime = *conv.EphemeralExpiration
|
||||
portal.Update(nil)
|
||||
err = portal.Update(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save portal after updating expiration time")
|
||||
}
|
||||
}
|
||||
|
||||
user.backfillInChunks(req, conv, portal)
|
||||
req.MarkDone()
|
||||
user.backfillInChunks(ctx, req, conv, portal)
|
||||
err = req.MarkDone(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to mark backfill request as done after backfilling")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -93,7 +93,7 @@ func (prov *ProvisioningAPI) BridgeStatePing(w http.ResponseWriter, r *http.Requ
|
|||
remote = remote.Fill(user)
|
||||
resp.RemoteStates[remote.RemoteID] = remote
|
||||
}
|
||||
user.log.Debugfln("Responding bridge state in bridge status endpoint: %+v", resp)
|
||||
user.zlog.Debug().Any("response_data", &resp).Msg("Responding bridge state in bridge status endpoint")
|
||||
jsonResponse(w, http.StatusOK, &resp)
|
||||
if len(resp.RemoteStates) > 0 {
|
||||
user.BridgeState.SetPrev(remote)
|
||||
|
|
154
commands.go
154
commands.go
|
@ -29,6 +29,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/skip2/go-qrcode"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
|
@ -37,7 +38,6 @@ import (
|
|||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/bridge/commands"
|
||||
"maunium.net/go/mautrix/bridge/status"
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
@ -117,7 +117,10 @@ func fnSetRelay(ce *WrappedCommandEvent) {
|
|||
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
|
||||
} else {
|
||||
ce.Portal.RelayUserID = ce.User.MXID
|
||||
ce.Portal.Update(nil)
|
||||
err := ce.Portal.Update(ce.Ctx)
|
||||
if err != nil {
|
||||
ce.ZLog.Err(err).Msg("Failed to save portal after setting relay user")
|
||||
}
|
||||
ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account")
|
||||
}
|
||||
}
|
||||
|
@ -139,7 +142,10 @@ func fnUnsetRelay(ce *WrappedCommandEvent) {
|
|||
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
|
||||
} else {
|
||||
ce.Portal.RelayUserID = ""
|
||||
ce.Portal.Update(nil)
|
||||
err := ce.Portal.Update(ce.Ctx)
|
||||
if err != nil {
|
||||
ce.ZLog.Err(err).Msg("Failed to save portal after clearing relay user")
|
||||
}
|
||||
ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
|
||||
}
|
||||
}
|
||||
|
@ -246,7 +252,7 @@ func fnJoin(ce *WrappedCommandEvent) {
|
|||
ce.Reply("Failed to join group: %v", err)
|
||||
return
|
||||
}
|
||||
ce.Log.Debugln("%s successfully joined group %s", ce.User.MXID, jid)
|
||||
ce.ZLog.Debug().Stringer("group_jid", jid).Msg("User successfully joined WhatsApp group with link")
|
||||
ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
|
||||
} else if strings.HasPrefix(ce.Args[0], whatsmeow.NewsletterLinkPrefix) {
|
||||
info, err := ce.User.Client.GetNewsletterInfoWithInvite(ce.Args[0])
|
||||
|
@ -259,14 +265,14 @@ func fnJoin(ce *WrappedCommandEvent) {
|
|||
ce.Reply("Failed to follow channel: %v", err)
|
||||
return
|
||||
}
|
||||
ce.Log.Debugln("%s successfully followed channel %s", ce.User.MXID, info.ID)
|
||||
ce.ZLog.Debug().Stringer("channel_jid", info.ID).Msg("User successfully followed WhatsApp channel with link")
|
||||
ce.Reply("Successfully followed channel `%s`, the portal should be created momentarily", info.ID)
|
||||
} else {
|
||||
ce.Reply("That doesn't look like a WhatsApp invite link")
|
||||
}
|
||||
}
|
||||
|
||||
func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) {
|
||||
func tryDecryptEvent(ce *WrappedCommandEvent, evt *event.Event) (json.RawMessage, error) {
|
||||
var data json.RawMessage
|
||||
if evt.Type != event.EventEncrypted {
|
||||
data = evt.Content.VeryRaw
|
||||
|
@ -275,7 +281,7 @@ func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, e
|
|||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
return nil, err
|
||||
}
|
||||
decrypted, err := crypto.Decrypt(evt)
|
||||
decrypted, err := ce.Bridge.Crypto.Decrypt(ce.Ctx, evt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -311,11 +317,11 @@ var cmdAccept = &commands.FullHandler{
|
|||
func fnAccept(ce *WrappedCommandEvent) {
|
||||
if len(ce.ReplyTo) == 0 {
|
||||
ce.Reply("You must reply to a group invite message when using this command.")
|
||||
} else if evt, err := ce.Portal.MainIntent().GetEvent(ce.RoomID, ce.ReplyTo); err != nil {
|
||||
ce.Log.Errorln("Failed to get event %s to handle !wa accept command: %v", ce.ReplyTo, err)
|
||||
} else if evt, err := ce.Portal.MainIntent().GetEvent(ce.Ctx, ce.RoomID, ce.ReplyTo); err != nil {
|
||||
ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to get reply target event to handle !wa accept command")
|
||||
ce.Reply("Failed to get reply event")
|
||||
} else if rawContent, err := tryDecryptEvent(ce.Bridge.Crypto, evt); err != nil {
|
||||
ce.Log.Errorln("Failed to decrypt event %s to handle !wa accept command: %v", ce.ReplyTo, err)
|
||||
} else if rawContent, err := tryDecryptEvent(ce, evt); err != nil {
|
||||
ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to decrypt reply target event to handle !wa accept command")
|
||||
ce.Reply("Failed to decrypt reply event")
|
||||
} else if meta, err := parseInviteMeta(rawContent); err != nil || meta == nil {
|
||||
ce.Reply("That doesn't look like a group invite message.")
|
||||
|
@ -344,16 +350,16 @@ func fnCreate(ce *WrappedCommandEvent) {
|
|||
return
|
||||
}
|
||||
|
||||
members, err := ce.Bot.JoinedMembers(ce.RoomID)
|
||||
members, err := ce.Bot.JoinedMembers(ce.Ctx, ce.RoomID)
|
||||
if err != nil {
|
||||
ce.Reply("Failed to get room members: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var roomNameEvent event.RoomNameEventContent
|
||||
err = ce.Bot.StateEvent(ce.RoomID, event.StateRoomName, "", &roomNameEvent)
|
||||
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateRoomName, "", &roomNameEvent)
|
||||
if err != nil && !errors.Is(err, mautrix.MNotFound) {
|
||||
ce.Log.Errorln("Failed to get room name to create group:", err)
|
||||
ce.ZLog.Err(err).Msg("Failed to get room name to create group")
|
||||
ce.Reply("Failed to get room name")
|
||||
return
|
||||
} else if len(roomNameEvent.Name) == 0 {
|
||||
|
@ -362,15 +368,17 @@ func fnCreate(ce *WrappedCommandEvent) {
|
|||
}
|
||||
|
||||
var encryptionEvent event.EncryptionEventContent
|
||||
err = ce.Bot.StateEvent(ce.RoomID, event.StateEncryption, "", &encryptionEvent)
|
||||
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateEncryption, "", &encryptionEvent)
|
||||
if err != nil && !errors.Is(err, mautrix.MNotFound) {
|
||||
ce.ZLog.Err(err).Msg("Failed to get room encryption status to create group")
|
||||
ce.Reply("Failed to get room encryption status")
|
||||
return
|
||||
}
|
||||
|
||||
var createEvent event.CreateEventContent
|
||||
err = ce.Bot.StateEvent(ce.RoomID, event.StateCreate, "", &createEvent)
|
||||
err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateCreate, "", &createEvent)
|
||||
if err != nil && !errors.Is(err, mautrix.MNotFound) {
|
||||
ce.ZLog.Err(err).Msg("Failed to get room create event to create group")
|
||||
ce.Reply("Failed to get room create event")
|
||||
return
|
||||
}
|
||||
|
@ -395,7 +403,11 @@ func fnCreate(ce *WrappedCommandEvent) {
|
|||
// TODO check m.space.parent to create rooms directly in communities
|
||||
|
||||
messageID := ce.User.Client.GenerateMessageID()
|
||||
ce.Log.Infofln("Creating group for %s with name %s and participants %+v (create key: %s)", ce.RoomID, roomNameEvent.Name, participants, messageID)
|
||||
ce.ZLog.Info().
|
||||
Str("room_name", roomNameEvent.Name).
|
||||
Any("participants", participants).
|
||||
Str("create_key", messageID).
|
||||
Msg("Creating WhatsApp group for Matrix room")
|
||||
ce.User.createKeyDedup = messageID
|
||||
resp, err := ce.User.Client.CreateGroup(whatsmeow.ReqCreateGroup{
|
||||
CreateKey: messageID,
|
||||
|
@ -409,21 +421,25 @@ func fnCreate(ce *WrappedCommandEvent) {
|
|||
ce.Reply("Failed to create group: %v", err)
|
||||
return
|
||||
}
|
||||
ce.ZLog.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("group_jid", resp.JID.String())
|
||||
})
|
||||
portal := ce.User.GetPortalByJID(resp.JID)
|
||||
portal.roomCreateLock.Lock()
|
||||
defer portal.roomCreateLock.Unlock()
|
||||
if len(portal.MXID) != 0 {
|
||||
portal.log.Warnln("Detected race condition in room creation")
|
||||
ce.ZLog.Warn().Msg("Detected race condition in room creation")
|
||||
// TODO race condition, clean up the old room
|
||||
}
|
||||
portal.MXID = ce.RoomID
|
||||
portal.updateLogger()
|
||||
portal.Name = roomNameEvent.Name
|
||||
portal.IsParent = resp.IsParent
|
||||
portal.Encrypted = encryptionEvent.Algorithm == id.AlgorithmMegolmV1
|
||||
if !portal.Encrypted && ce.Bridge.Config.Bridge.Encryption.Default {
|
||||
_, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent())
|
||||
_, err = portal.MainIntent().SendStateEvent(ce.Ctx, portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent())
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to enable encryption in room:", err)
|
||||
ce.ZLog.Err(err).Msg("Failed to enable encryption in room")
|
||||
if errors.Is(err, mautrix.MForbidden) {
|
||||
ce.Reply("I don't seem to have permission to enable encryption in this room.")
|
||||
} else {
|
||||
|
@ -433,8 +449,11 @@ func fnCreate(ce *WrappedCommandEvent) {
|
|||
portal.Encrypted = true
|
||||
}
|
||||
|
||||
portal.Update(nil)
|
||||
portal.UpdateBridgeInfo()
|
||||
err = portal.Update(ce.Ctx)
|
||||
if err != nil {
|
||||
ce.ZLog.Err(err).Msg("Failed to save portal after creating group")
|
||||
}
|
||||
portal.UpdateBridgeInfo(ce.Ctx)
|
||||
ce.User.createKeyDedup = ""
|
||||
|
||||
ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID)
|
||||
|
@ -512,7 +531,7 @@ func fnLogin(ce *WrappedCommandEvent) {
|
|||
}
|
||||
}
|
||||
if qrEventID != "" {
|
||||
_, _ = ce.Bot.RedactEvent(ce.RoomID, qrEventID)
|
||||
_, _ = ce.Bot.RedactEvent(ce.Ctx, ce.RoomID, qrEventID)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -529,9 +548,9 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
|
|||
if len(prevEvent) != 0 {
|
||||
content.SetEdit(prevEvent)
|
||||
}
|
||||
resp, err := ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &content)
|
||||
resp, err := ce.Bot.SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, &content)
|
||||
if err != nil {
|
||||
user.log.Errorln("Failed to send edited QR code to user:", err)
|
||||
ce.ZLog.Err(err).Msg("Failed to send edited QR code to user")
|
||||
} else if len(prevEvent) == 0 {
|
||||
prevEvent = resp.EventID
|
||||
}
|
||||
|
@ -541,16 +560,16 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even
|
|||
func (user *User) uploadQR(ce *WrappedCommandEvent, code string) (id.ContentURI, bool) {
|
||||
qrCode, err := qrcode.Encode(code, qrcode.Low, 256)
|
||||
if err != nil {
|
||||
user.log.Errorln("Failed to encode QR code:", err)
|
||||
ce.ZLog.Err(err).Msg("Failed to encode QR code")
|
||||
ce.Reply("Failed to encode QR code: %v", err)
|
||||
return id.ContentURI{}, false
|
||||
}
|
||||
|
||||
bot := user.bridge.AS.BotClient()
|
||||
|
||||
resp, err := bot.UploadBytes(qrCode, "image/png")
|
||||
resp, err := bot.UploadBytes(ce.Ctx, qrCode, "image/png")
|
||||
if err != nil {
|
||||
user.log.Errorln("Failed to upload QR code:", err)
|
||||
ce.ZLog.Err(err).Msg("Failed to upload QR code")
|
||||
ce.Reply("Failed to upload QR code: %v", err)
|
||||
return id.ContentURI{}, false
|
||||
}
|
||||
|
@ -578,14 +597,14 @@ func fnLogout(ce *WrappedCommandEvent) {
|
|||
puppet.ClearCustomMXID()
|
||||
err := ce.User.Client.Logout()
|
||||
if err != nil {
|
||||
ce.User.log.Warnln("Error while logging out:", err)
|
||||
ce.ZLog.Err(err).Msg("Unknown error while logging out")
|
||||
ce.Reply("Unknown error while logging out: %v", err)
|
||||
return
|
||||
}
|
||||
ce.User.Session = nil
|
||||
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
|
||||
ce.User.DeleteConnection()
|
||||
ce.User.DeleteSession()
|
||||
ce.User.DeleteSession(ce.Ctx)
|
||||
ce.Reply("Logged out successfully.")
|
||||
}
|
||||
|
||||
|
@ -620,10 +639,13 @@ func fnTogglePresence(ce *WrappedCommandEvent) {
|
|||
if ce.User.IsLoggedIn() {
|
||||
err := ce.User.Client.SendPresence(newPresence)
|
||||
if err != nil {
|
||||
ce.User.log.Warnln("Failed to set presence:", err)
|
||||
ce.ZLog.Err(err).Msg("Failed to send presence to WhatsApp")
|
||||
}
|
||||
}
|
||||
customPuppet.Update()
|
||||
err := customPuppet.Update(ce.Ctx)
|
||||
if err != nil {
|
||||
ce.ZLog.Err(err).Msg("Failed to save puppet after toggling presence")
|
||||
}
|
||||
}
|
||||
|
||||
var cmdDeleteSession = &commands.FullHandler{
|
||||
|
@ -642,7 +664,7 @@ func fnDeleteSession(ce *WrappedCommandEvent) {
|
|||
}
|
||||
ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
|
||||
ce.User.DeleteConnection()
|
||||
ce.User.DeleteSession()
|
||||
ce.User.DeleteSession(ce.Ctx)
|
||||
ce.Reply("Session information purged")
|
||||
}
|
||||
|
||||
|
@ -716,19 +738,19 @@ func fnPing(ce *WrappedCommandEvent) {
|
|||
}
|
||||
}
|
||||
|
||||
func canDeletePortal(portal *Portal, userID id.UserID) bool {
|
||||
func canDeletePortal(ce *WrappedCommandEvent, portal *Portal) bool {
|
||||
if len(portal.MXID) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
members, err := portal.MainIntent().JoinedMembers(portal.MXID)
|
||||
members, err := portal.MainIntent().JoinedMembers(ce.Ctx, portal.MXID)
|
||||
if err != nil {
|
||||
portal.log.Errorfln("Failed to get joined members to check if portal can be deleted by %s: %v", userID, err)
|
||||
ce.ZLog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to get joined members to check if portal can be deleted by user")
|
||||
return false
|
||||
}
|
||||
for otherUser := range members.Joined {
|
||||
_, isPuppet := portal.bridge.ParsePuppetMXID(otherUser)
|
||||
if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == userID {
|
||||
if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == ce.User.MXID {
|
||||
continue
|
||||
}
|
||||
user := portal.bridge.GetUserByMXID(otherUser)
|
||||
|
@ -750,14 +772,14 @@ var cmdDeletePortal = &commands.FullHandler{
|
|||
}
|
||||
|
||||
func fnDeletePortal(ce *WrappedCommandEvent) {
|
||||
if !ce.User.Admin && !canDeletePortal(ce.Portal, ce.User.MXID) {
|
||||
if !ce.User.Admin && !canDeletePortal(ce, ce.Portal) {
|
||||
ce.Reply("Only bridge admins can delete portals with other Matrix users")
|
||||
return
|
||||
}
|
||||
|
||||
ce.Portal.log.Infoln(ce.User.MXID, "requested deletion of portal.")
|
||||
ce.Portal.Delete()
|
||||
ce.Portal.Cleanup(false)
|
||||
ce.ZLog.Info().Msg("User requested deletion of current portal")
|
||||
ce.Portal.Delete(ce.Ctx)
|
||||
ce.Portal.Cleanup(ce.Ctx, false)
|
||||
}
|
||||
|
||||
var cmdDeleteAllPortals = &commands.FullHandler{
|
||||
|
@ -778,7 +800,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
|
|||
} else {
|
||||
portalsToDelete = portals[:0]
|
||||
for _, portal := range portals {
|
||||
if canDeletePortal(portal, ce.User.MXID) {
|
||||
if canDeletePortal(ce, portal) {
|
||||
portalsToDelete = append(portalsToDelete, portal)
|
||||
}
|
||||
}
|
||||
|
@ -790,7 +812,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
|
|||
|
||||
leave := func(portal *Portal) {
|
||||
if len(portal.MXID) > 0 {
|
||||
_, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{
|
||||
_, _ = portal.MainIntent().KickUser(ce.Ctx, portal.MXID, &mautrix.ReqKickUser{
|
||||
Reason: "Deleting portal",
|
||||
UserID: ce.User.MXID,
|
||||
})
|
||||
|
@ -801,21 +823,21 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) {
|
|||
intent := customPuppet.CustomIntent()
|
||||
leave = func(portal *Portal) {
|
||||
if len(portal.MXID) > 0 {
|
||||
_, _ = intent.LeaveRoom(portal.MXID)
|
||||
_, _ = intent.ForgetRoom(portal.MXID)
|
||||
_, _ = intent.LeaveRoom(ce.Ctx, portal.MXID)
|
||||
_, _ = intent.ForgetRoom(ce.Ctx, portal.MXID)
|
||||
}
|
||||
}
|
||||
}
|
||||
ce.Reply("Found %d portals, deleting...", len(portalsToDelete))
|
||||
for _, portal := range portalsToDelete {
|
||||
portal.Delete()
|
||||
portal.Delete(ce.Ctx)
|
||||
leave(portal)
|
||||
}
|
||||
ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.")
|
||||
|
||||
go func() {
|
||||
for _, portal := range portalsToDelete {
|
||||
portal.Cleanup(false)
|
||||
portal.Cleanup(ce.Ctx, false)
|
||||
}
|
||||
ce.Reply("Finished background cleanup of deleted portal rooms.")
|
||||
}()
|
||||
|
@ -882,7 +904,7 @@ func fnList(ce *WrappedCommandEvent) {
|
|||
}
|
||||
var err error
|
||||
page := 1
|
||||
max := 100
|
||||
maxPerPage := 100
|
||||
if len(ce.Args) > 1 {
|
||||
page, err = strconv.Atoi(ce.Args[1])
|
||||
if err != nil || page <= 0 {
|
||||
|
@ -891,11 +913,11 @@ func fnList(ce *WrappedCommandEvent) {
|
|||
}
|
||||
}
|
||||
if len(ce.Args) > 2 {
|
||||
max, err = strconv.Atoi(ce.Args[2])
|
||||
if err != nil || max <= 0 {
|
||||
maxPerPage, err = strconv.Atoi(ce.Args[2])
|
||||
if err != nil || maxPerPage <= 0 {
|
||||
ce.Reply("\"%s\" isn't a valid number of items per page", ce.Args[2])
|
||||
return
|
||||
} else if max > 400 {
|
||||
} else if maxPerPage > 400 {
|
||||
ce.Reply("Warning: a high number of items per page may fail to send a reply")
|
||||
}
|
||||
}
|
||||
|
@ -924,8 +946,8 @@ func fnList(ce *WrappedCommandEvent) {
|
|||
ce.Reply("No %s found", strings.ToLower(typeName))
|
||||
return
|
||||
}
|
||||
pages := int(math.Ceil(float64(len(result)) / float64(max)))
|
||||
if (page-1)*max >= len(result) {
|
||||
pages := int(math.Ceil(float64(len(result)) / float64(maxPerPage)))
|
||||
if (page-1)*maxPerPage >= len(result) {
|
||||
if pages == 1 {
|
||||
ce.Reply("There is only 1 page of %s", strings.ToLower(typeName))
|
||||
} else {
|
||||
|
@ -933,11 +955,11 @@ func fnList(ce *WrappedCommandEvent) {
|
|||
}
|
||||
return
|
||||
}
|
||||
lastIndex := page * max
|
||||
lastIndex := page * maxPerPage
|
||||
if lastIndex > len(result) {
|
||||
lastIndex = len(result)
|
||||
}
|
||||
result = result[(page-1)*max : lastIndex]
|
||||
result = result[(page-1)*maxPerPage : lastIndex]
|
||||
ce.Reply("### %s (page %d of %d)\n\n%s", typeName, page, pages, strings.Join(result, "\n"))
|
||||
}
|
||||
|
||||
|
@ -1036,13 +1058,13 @@ func fnOpen(ce *WrappedCommandEvent) {
|
|||
}
|
||||
jid = newsletterMetadata.ID
|
||||
}
|
||||
ce.Log.Debugln("Importing", jid, "for", ce.User.MXID)
|
||||
ce.ZLog.Debug().Stringer("chat_jid", jid).Msg("Importing chat for user")
|
||||
portal := ce.User.GetPortalByJID(jid)
|
||||
if len(portal.MXID) > 0 {
|
||||
portal.UpdateMatrixRoom(ce.User, groupInfo, newsletterMetadata)
|
||||
portal.UpdateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata)
|
||||
ce.Reply("Portal room synced.")
|
||||
} else {
|
||||
err = portal.CreateMatrixRoom(ce.User, groupInfo, newsletterMetadata, true, true)
|
||||
err = portal.CreateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata, true, true)
|
||||
if err != nil {
|
||||
ce.Reply("Failed to create room: %v", err)
|
||||
} else {
|
||||
|
@ -1085,7 +1107,7 @@ func fnPM(ce *WrappedCommandEvent) {
|
|||
return
|
||||
}
|
||||
|
||||
portal, puppet, justCreated, err := user.StartPM(targetUser.JID, "manual PM command")
|
||||
portal, puppet, justCreated, err := user.StartPM(ce.Ctx, targetUser.JID, "manual PM command")
|
||||
if err != nil {
|
||||
ce.Reply("Failed to create portal room: %v", err)
|
||||
} else if !justCreated {
|
||||
|
@ -1154,11 +1176,16 @@ func fnSync(ce *WrappedCommandEvent) {
|
|||
ce.Reply("Personal filtering spaces are not enabled on this instance of the bridge")
|
||||
return
|
||||
}
|
||||
keys := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.JID)
|
||||
keys, err := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.Ctx, ce.User.JID)
|
||||
if err != nil {
|
||||
ce.ZLog.Err(err).Msg("Failed to get list of private chats not in space")
|
||||
ce.Reply("Failed to get list of private chats not in space")
|
||||
return
|
||||
}
|
||||
count := 0
|
||||
for _, key := range keys {
|
||||
portal := ce.Bridge.GetPortalByJID(key)
|
||||
portal.addToPersonalSpace(ce.User)
|
||||
portal.addToPersonalSpace(ce.Ctx, ce.User)
|
||||
count++
|
||||
}
|
||||
plural := "s"
|
||||
|
@ -1208,6 +1235,9 @@ func fnDisappearingTimer(ce *WrappedCommandEvent) {
|
|||
ce.Portal.ExpirationTime = prevExpirationTime
|
||||
return
|
||||
}
|
||||
ce.Portal.Update(nil)
|
||||
err = ce.Portal.Update(ce.Ctx)
|
||||
if err != nil {
|
||||
ce.ZLog.Err(err).Msg("Failed to save portal after setting disappearing timer")
|
||||
}
|
||||
ce.React("✅")
|
||||
}
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
|
@ -24,8 +27,11 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error
|
|||
puppet.CustomMXID = mxid
|
||||
puppet.AccessToken = accessToken
|
||||
puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence
|
||||
puppet.Update()
|
||||
err := puppet.StartCustomMXID(false)
|
||||
err := puppet.Update(context.TODO())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save access token: %w", err)
|
||||
}
|
||||
err = puppet.StartCustomMXID(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -45,12 +51,15 @@ func (puppet *Puppet) ClearCustomMXID() {
|
|||
puppet.customIntent = nil
|
||||
puppet.customUser = nil
|
||||
if save {
|
||||
puppet.Update()
|
||||
err := puppet.Update(context.TODO())
|
||||
if err != nil {
|
||||
puppet.zlog.Err(err).Msg("Failed to clear custom MXID")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
|
||||
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
|
||||
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(context.TODO(), puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
|
||||
if err != nil {
|
||||
puppet.ClearCustomMXID()
|
||||
return err
|
||||
|
@ -60,11 +69,11 @@ func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
|
|||
puppet.bridge.puppetsLock.Unlock()
|
||||
if puppet.AccessToken != newAccessToken {
|
||||
puppet.AccessToken = newAccessToken
|
||||
puppet.Update()
|
||||
err = puppet.Update(context.TODO())
|
||||
}
|
||||
puppet.customIntent = newIntent
|
||||
puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID)
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (user *User) tryAutomaticDoublePuppeting() {
|
||||
|
|
|
@ -1,340 +0,0 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2021 Tulir Asokan, Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type BackfillType int
|
||||
|
||||
const (
|
||||
BackfillImmediate BackfillType = 0
|
||||
BackfillForward BackfillType = 100
|
||||
BackfillDeferred BackfillType = 200
|
||||
)
|
||||
|
||||
func (bt BackfillType) String() string {
|
||||
switch bt {
|
||||
case BackfillImmediate:
|
||||
return "IMMEDIATE"
|
||||
case BackfillForward:
|
||||
return "FORWARD"
|
||||
case BackfillDeferred:
|
||||
return "DEFERRED"
|
||||
}
|
||||
return "UNKNOWN"
|
||||
}
|
||||
|
||||
type BackfillQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
backfillQueryLock sync.Mutex
|
||||
}
|
||||
|
||||
func (bq *BackfillQuery) New() *Backfill {
|
||||
return &Backfill{
|
||||
db: bq.db,
|
||||
log: bq.log,
|
||||
Portal: &PortalKey{},
|
||||
}
|
||||
}
|
||||
|
||||
func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill {
|
||||
return &Backfill{
|
||||
db: bq.db,
|
||||
log: bq.log,
|
||||
UserID: userID,
|
||||
BackfillType: backfillType,
|
||||
Priority: priority,
|
||||
Portal: portal,
|
||||
TimeStart: timeStart,
|
||||
MaxBatchEvents: maxBatchEvents,
|
||||
MaxTotalEvents: maxTotalEvents,
|
||||
BatchDelay: batchDelay,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
getNextBackfillQuery = `
|
||||
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
|
||||
FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND type IN (%s)
|
||||
AND (
|
||||
dispatch_time IS NULL
|
||||
OR (
|
||||
dispatch_time < $2
|
||||
AND completed_at IS NULL
|
||||
)
|
||||
)
|
||||
ORDER BY type, priority, queue_id
|
||||
LIMIT 1
|
||||
`
|
||||
getUnstartedOrInFlightQuery = `
|
||||
SELECT 1
|
||||
FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND type IN (%s)
|
||||
AND (dispatch_time IS NULL OR completed_at IS NULL)
|
||||
LIMIT 1
|
||||
`
|
||||
)
|
||||
|
||||
// GetNext returns the next backfill to perform
|
||||
func (bq *BackfillQuery) GetNext(userID id.UserID, backfillTypes []BackfillType) (backfill *Backfill) {
|
||||
bq.backfillQueryLock.Lock()
|
||||
defer bq.backfillQueryLock.Unlock()
|
||||
|
||||
var types []string
|
||||
for _, backfillType := range backfillTypes {
|
||||
types = append(types, strconv.Itoa(int(backfillType)))
|
||||
}
|
||||
rows, err := bq.db.Query(fmt.Sprintf(getNextBackfillQuery, strings.Join(types, ",")), userID, time.Now().Add(-15*time.Minute))
|
||||
if err != nil || rows == nil {
|
||||
bq.log.Errorfln("Failed to query next backfill queue job: %v", err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
backfill = bq.New().Scan(rows)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (bq *BackfillQuery) HasUnstartedOrInFlightOfType(userID id.UserID, backfillTypes []BackfillType) bool {
|
||||
if len(backfillTypes) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
bq.backfillQueryLock.Lock()
|
||||
defer bq.backfillQueryLock.Unlock()
|
||||
|
||||
types := []string{}
|
||||
for _, backfillType := range backfillTypes {
|
||||
types = append(types, strconv.Itoa(int(backfillType)))
|
||||
}
|
||||
rows, err := bq.db.Query(fmt.Sprintf(getUnstartedOrInFlightQuery, strings.Join(types, ",")), userID)
|
||||
if err != nil || rows == nil {
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
bq.log.Warnfln("Failed to query backfill queue jobs: %v", err)
|
||||
}
|
||||
// No rows means that there are no unstarted or in flight backfill
|
||||
// requests.
|
||||
return false
|
||||
}
|
||||
defer rows.Close()
|
||||
return rows.Next()
|
||||
}
|
||||
|
||||
func (bq *BackfillQuery) DeleteAll(userID id.UserID) {
|
||||
bq.backfillQueryLock.Lock()
|
||||
defer bq.backfillQueryLock.Unlock()
|
||||
_, err := bq.db.Exec("DELETE FROM backfill_queue WHERE user_mxid=$1", userID)
|
||||
if err != nil {
|
||||
bq.log.Warnfln("Failed to delete backfill queue items for %s: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (bq *BackfillQuery) DeleteAllForPortal(userID id.UserID, portalKey PortalKey) {
|
||||
bq.backfillQueryLock.Lock()
|
||||
defer bq.backfillQueryLock.Unlock()
|
||||
_, err := bq.db.Exec(`
|
||||
DELETE FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND portal_jid=$2
|
||||
AND portal_receiver=$3
|
||||
`, userID, portalKey.JID, portalKey.Receiver)
|
||||
if err != nil {
|
||||
bq.log.Warnfln("Failed to delete backfill queue items for %s/%s: %v", userID, portalKey.JID, err)
|
||||
}
|
||||
}
|
||||
|
||||
type Backfill struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
// Fields
|
||||
QueueID int
|
||||
UserID id.UserID
|
||||
BackfillType BackfillType
|
||||
Priority int
|
||||
Portal *PortalKey
|
||||
TimeStart *time.Time
|
||||
MaxBatchEvents int
|
||||
MaxTotalEvents int
|
||||
BatchDelay int
|
||||
DispatchTime *time.Time
|
||||
CompletedAt *time.Time
|
||||
}
|
||||
|
||||
func (b *Backfill) String() string {
|
||||
return fmt.Sprintf("Backfill{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
|
||||
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
|
||||
)
|
||||
}
|
||||
|
||||
func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
|
||||
var maxTotalEvents, batchDelay sql.NullInt32
|
||||
err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &maxTotalEvents, &batchDelay)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
b.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
b.MaxTotalEvents = int(maxTotalEvents.Int32)
|
||||
b.BatchDelay = int(batchDelay.Int32)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Backfill) Insert() {
|
||||
b.db.Backfill.backfillQueryLock.Lock()
|
||||
defer b.db.Backfill.backfillQueryLock.Unlock()
|
||||
|
||||
rows, err := b.db.Query(`
|
||||
INSERT INTO backfill_queue
|
||||
(user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING queue_id
|
||||
`, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt)
|
||||
defer rows.Close()
|
||||
if err != nil || !rows.Next() {
|
||||
b.log.Warnfln("Failed to insert %v/%s with priority %d: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
|
||||
return
|
||||
}
|
||||
err = rows.Scan(&b.QueueID)
|
||||
if err != nil {
|
||||
b.log.Warnfln("Failed to insert %s/%s with priority %s: %v", b.BackfillType, b.Portal.JID, b.Priority, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backfill) MarkDispatched() {
|
||||
b.db.Backfill.backfillQueryLock.Lock()
|
||||
defer b.db.Backfill.backfillQueryLock.Unlock()
|
||||
|
||||
if b.QueueID == 0 {
|
||||
b.log.Errorfln("Cannot mark backfill as dispatched without queue_id. Maybe it wasn't actually inserted in the database?")
|
||||
return
|
||||
}
|
||||
_, err := b.db.Exec("UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
|
||||
if err != nil {
|
||||
b.log.Warnfln("Failed to mark %s/%s as dispatched: %v", b.BackfillType, b.Priority, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backfill) MarkDone() {
|
||||
b.db.Backfill.backfillQueryLock.Lock()
|
||||
defer b.db.Backfill.backfillQueryLock.Unlock()
|
||||
|
||||
if b.QueueID == 0 {
|
||||
b.log.Errorfln("Cannot mark backfill done without queue_id. Maybe it wasn't actually inserted in the database?")
|
||||
return
|
||||
}
|
||||
_, err := b.db.Exec("UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2", time.Now(), b.QueueID)
|
||||
if err != nil {
|
||||
b.log.Warnfln("Failed to mark %s/%s as complete: %v", b.BackfillType, b.Priority, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (bq *BackfillQuery) NewBackfillState(userID id.UserID, portalKey *PortalKey) *BackfillState {
|
||||
return &BackfillState{
|
||||
db: bq.db,
|
||||
log: bq.log,
|
||||
UserID: userID,
|
||||
Portal: portalKey,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
getBackfillState = `
|
||||
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
|
||||
FROM backfill_state
|
||||
WHERE user_mxid=$1
|
||||
AND portal_jid=$2
|
||||
AND portal_receiver=$3
|
||||
`
|
||||
)
|
||||
|
||||
type BackfillState struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
// Fields
|
||||
UserID id.UserID
|
||||
Portal *PortalKey
|
||||
ProcessingBatch bool
|
||||
BackfillComplete bool
|
||||
FirstExpectedTimestamp uint64
|
||||
}
|
||||
|
||||
func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
|
||||
err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
b.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BackfillState) Upsert() {
|
||||
_, err := b.db.Exec(`
|
||||
INSERT INTO backfill_state
|
||||
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
|
||||
DO UPDATE SET
|
||||
processing_batch=EXCLUDED.processing_batch,
|
||||
backfill_complete=EXCLUDED.backfill_complete,
|
||||
first_expected_ts=EXCLUDED.first_expected_ts`,
|
||||
b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp)
|
||||
if err != nil {
|
||||
b.log.Warnfln("Failed to insert backfill state for %s: %v", b.Portal.JID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackfillState) SetProcessingBatch(processing bool) {
|
||||
b.ProcessingBatch = processing
|
||||
b.Upsert()
|
||||
}
|
||||
|
||||
func (bq *BackfillQuery) GetBackfillState(userID id.UserID, portalKey *PortalKey) (backfillState *BackfillState) {
|
||||
rows, err := bq.db.Query(getBackfillState, userID, portalKey.JID, portalKey.Receiver)
|
||||
if err != nil || rows == nil {
|
||||
bq.log.Error(err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
backfillState = bq.NewBackfillState(userID, portalKey).Scan(rows)
|
||||
}
|
||||
return
|
||||
}
|
253
database/backfillqueue.go
Normal file
253
database/backfillqueue.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type BackfillType int
|
||||
|
||||
const (
|
||||
BackfillImmediate BackfillType = 0
|
||||
BackfillForward BackfillType = 100
|
||||
BackfillDeferred BackfillType = 200
|
||||
)
|
||||
|
||||
func (bt BackfillType) String() string {
|
||||
switch bt {
|
||||
case BackfillImmediate:
|
||||
return "IMMEDIATE"
|
||||
case BackfillForward:
|
||||
return "FORWARD"
|
||||
case BackfillDeferred:
|
||||
return "DEFERRED"
|
||||
}
|
||||
return "UNKNOWN"
|
||||
}
|
||||
|
||||
type BackfillTaskQuery struct {
|
||||
*dbutil.QueryHelper[*BackfillTask]
|
||||
|
||||
//backfillQueryLock sync.Mutex
|
||||
}
|
||||
|
||||
func newBackfillTask(qh *dbutil.QueryHelper[*BackfillTask]) *BackfillTask {
|
||||
return &BackfillTask{qh: qh}
|
||||
}
|
||||
|
||||
func (bq *BackfillTaskQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *BackfillTask {
|
||||
return &BackfillTask{
|
||||
qh: bq.QueryHelper,
|
||||
|
||||
UserID: userID,
|
||||
BackfillType: backfillType,
|
||||
Priority: priority,
|
||||
Portal: portal,
|
||||
TimeStart: timeStart,
|
||||
MaxBatchEvents: maxBatchEvents,
|
||||
MaxTotalEvents: maxTotalEvents,
|
||||
BatchDelay: batchDelay,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
getNextBackfillTaskQueryTemplate = `
|
||||
SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay
|
||||
FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND type IN (%s)
|
||||
AND (
|
||||
dispatch_time IS NULL
|
||||
OR (
|
||||
dispatch_time < $2
|
||||
AND completed_at IS NULL
|
||||
)
|
||||
)
|
||||
ORDER BY type, priority, queue_id
|
||||
LIMIT 1
|
||||
`
|
||||
getUnstartedOrInFlightBackfillTaskQueryTemplate = `
|
||||
SELECT 1
|
||||
FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND type IN (%s)
|
||||
AND (dispatch_time IS NULL OR completed_at IS NULL)
|
||||
LIMIT 1
|
||||
`
|
||||
deleteBackfillQueueForUserQuery = "DELETE FROM backfill_queue WHERE user_mxid=$1"
|
||||
deleteBackfillQueueForPortalQuery = `
|
||||
DELETE FROM backfill_queue
|
||||
WHERE user_mxid=$1
|
||||
AND portal_jid=$2
|
||||
AND portal_receiver=$3
|
||||
`
|
||||
insertBackfillTaskQuery = `
|
||||
INSERT INTO backfill_queue (
|
||||
user_mxid, type, priority, portal_jid, portal_receiver, time_start,
|
||||
max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING queue_id
|
||||
`
|
||||
markBackfillTaskDispatchedQuery = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"
|
||||
markBackfillTaskDoneQuery = "UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2"
|
||||
)
|
||||
|
||||
func typesToString(backfillTypes []BackfillType) string {
|
||||
types := make([]string, len(backfillTypes))
|
||||
for i, backfillType := range backfillTypes {
|
||||
types[i] = strconv.Itoa(int(backfillType))
|
||||
}
|
||||
return strings.Join(types, ",")
|
||||
}
|
||||
|
||||
// GetNext returns the next backfill to perform
|
||||
func (bq *BackfillTaskQuery) GetNext(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (*BackfillTask, error) {
|
||||
if len(backfillTypes) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
//bq.backfillQueryLock.Lock()
|
||||
//defer bq.backfillQueryLock.Unlock()
|
||||
|
||||
query := fmt.Sprintf(getNextBackfillTaskQueryTemplate, typesToString(backfillTypes))
|
||||
return bq.QueryOne(ctx, query, userID, time.Now().Add(-15*time.Minute))
|
||||
}
|
||||
|
||||
func (bq *BackfillTaskQuery) HasUnstartedOrInFlightOfType(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (has bool) {
|
||||
if len(backfillTypes) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
//bq.backfillQueryLock.Lock()
|
||||
//defer bq.backfillQueryLock.Unlock()
|
||||
|
||||
query := fmt.Sprintf(getUnstartedOrInFlightBackfillTaskQueryTemplate, typesToString(backfillTypes))
|
||||
err := bq.GetDB().QueryRow(ctx, query, userID).Scan(&has)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to check if backfill queue has jobs")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (bq *BackfillTaskQuery) DeleteAll(ctx context.Context, userID id.UserID) error {
|
||||
//bq.backfillQueryLock.Lock()
|
||||
//defer bq.backfillQueryLock.Unlock()
|
||||
return bq.Exec(ctx, deleteBackfillQueueForUserQuery, userID)
|
||||
}
|
||||
|
||||
func (bq *BackfillTaskQuery) DeleteAllForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
|
||||
//bq.backfillQueryLock.Lock()
|
||||
//defer bq.backfillQueryLock.Unlock()
|
||||
return bq.Exec(ctx, deleteBackfillQueueForPortalQuery, userID, portalKey.JID, portalKey.Receiver)
|
||||
}
|
||||
|
||||
type BackfillTask struct {
|
||||
qh *dbutil.QueryHelper[*BackfillTask]
|
||||
|
||||
QueueID int
|
||||
UserID id.UserID
|
||||
BackfillType BackfillType
|
||||
Priority int
|
||||
Portal PortalKey
|
||||
TimeStart *time.Time
|
||||
MaxBatchEvents int
|
||||
MaxTotalEvents int
|
||||
BatchDelay int
|
||||
DispatchTime *time.Time
|
||||
CompletedAt *time.Time
|
||||
}
|
||||
|
||||
func (b *BackfillTask) MarshalZerologObject(evt *zerolog.Event) {
|
||||
evt.Int("queue_id", b.QueueID).
|
||||
Stringer("user_id", b.UserID).
|
||||
Stringer("backfill_type", b.BackfillType).
|
||||
Int("priority", b.Priority).
|
||||
Stringer("portal_jid", b.Portal.JID).
|
||||
Any("time_start", b.TimeStart).
|
||||
Int("max_batch_events", b.MaxBatchEvents).
|
||||
Int("max_total_events", b.MaxTotalEvents).
|
||||
Int("batch_delay", b.BatchDelay).
|
||||
Any("dispatch_time", b.DispatchTime).
|
||||
Any("completed_at", b.CompletedAt)
|
||||
}
|
||||
|
||||
func (b *BackfillTask) String() string {
|
||||
return fmt.Sprintf(
|
||||
"BackfillTask{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}",
|
||||
b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime,
|
||||
)
|
||||
}
|
||||
|
||||
func (b *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) {
|
||||
var maxTotalEvents, batchDelay sql.NullInt32
|
||||
err := row.Scan(
|
||||
&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart,
|
||||
&b.MaxBatchEvents, &maxTotalEvents, &batchDelay,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.MaxTotalEvents = int(maxTotalEvents.Int32)
|
||||
b.BatchDelay = int(batchDelay.Int32)
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *BackfillTask) sqlVariables() []any {
|
||||
return []any{
|
||||
b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart,
|
||||
b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BackfillTask) Insert(ctx context.Context) error {
|
||||
//b.db.Backfill.backfillQueryLock.Lock()
|
||||
//defer b.db.Backfill.backfillQueryLock.Unlock()
|
||||
|
||||
return b.qh.GetDB().QueryRow(ctx, insertBackfillTaskQuery, b.sqlVariables()...).Scan(&b.QueueID)
|
||||
}
|
||||
|
||||
func (b *BackfillTask) MarkDispatched(ctx context.Context) error {
|
||||
//b.db.Backfill.backfillQueryLock.Lock()
|
||||
//defer b.db.Backfill.backfillQueryLock.Unlock()
|
||||
|
||||
if b.QueueID == 0 {
|
||||
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
|
||||
}
|
||||
return b.qh.Exec(ctx, markBackfillTaskDispatchedQuery, time.Now(), b.QueueID)
|
||||
}
|
||||
|
||||
func (b *BackfillTask) MarkDone(ctx context.Context) error {
|
||||
//b.db.Backfill.backfillQueryLock.Lock()
|
||||
//defer b.db.Backfill.backfillQueryLock.Unlock()
|
||||
|
||||
if b.QueueID == 0 {
|
||||
return fmt.Errorf("can't mark backfill as dispatched without queue_id")
|
||||
}
|
||||
return b.qh.Exec(ctx, markBackfillTaskDoneQuery, time.Now(), b.QueueID)
|
||||
}
|
94
database/backfillstate.go
Normal file
94
database/backfillstate.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type BackfillStateQuery struct {
|
||||
*dbutil.QueryHelper[*BackfillState]
|
||||
}
|
||||
|
||||
func newBackfillState(qh *dbutil.QueryHelper[*BackfillState]) *BackfillState {
|
||||
return &BackfillState{qh: qh}
|
||||
}
|
||||
|
||||
func (bq *BackfillStateQuery) NewBackfillState(userID id.UserID, portalKey PortalKey) *BackfillState {
|
||||
return &BackfillState{
|
||||
qh: bq.QueryHelper,
|
||||
|
||||
UserID: userID,
|
||||
Portal: portalKey,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
getBackfillStateQuery = `
|
||||
SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts
|
||||
FROM backfill_state
|
||||
WHERE user_mxid=$1
|
||||
AND portal_jid=$2
|
||||
AND portal_receiver=$3
|
||||
`
|
||||
upsertBackfillStateQuery = `
|
||||
INSERT INTO backfill_state
|
||||
(user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver)
|
||||
DO UPDATE SET
|
||||
processing_batch=EXCLUDED.processing_batch,
|
||||
backfill_complete=EXCLUDED.backfill_complete,
|
||||
first_expected_ts=EXCLUDED.first_expected_ts
|
||||
`
|
||||
)
|
||||
|
||||
func (bq *BackfillStateQuery) GetBackfillState(ctx context.Context, userID id.UserID, portalKey PortalKey) (*BackfillState, error) {
|
||||
return bq.QueryOne(ctx, getBackfillStateQuery, userID, portalKey.JID, portalKey.Receiver)
|
||||
}
|
||||
|
||||
type BackfillState struct {
|
||||
qh *dbutil.QueryHelper[*BackfillState]
|
||||
|
||||
UserID id.UserID
|
||||
Portal PortalKey
|
||||
ProcessingBatch bool
|
||||
BackfillComplete bool
|
||||
FirstExpectedTimestamp uint64
|
||||
}
|
||||
|
||||
func (b *BackfillState) Scan(row dbutil.Scannable) (*BackfillState, error) {
|
||||
return dbutil.ValueOrErr(b, row.Scan(
|
||||
&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp,
|
||||
))
|
||||
}
|
||||
|
||||
func (b *BackfillState) sqlVariables() []any {
|
||||
return []any{b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp}
|
||||
}
|
||||
|
||||
func (b *BackfillState) Upsert(ctx context.Context) error {
|
||||
return b.qh.Exec(ctx, upsertBackfillStateQuery, b.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (b *BackfillState) SetProcessingBatch(ctx context.Context, processing bool) error {
|
||||
b.ProcessingBatch = processing
|
||||
return b.Upsert(ctx)
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -26,7 +26,6 @@ import (
|
|||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/whatsmeow/store"
|
||||
"go.mau.fi/whatsmeow/store/sqlstore"
|
||||
"maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix-whatsapp/database/upgrades"
|
||||
)
|
||||
|
@ -45,51 +44,28 @@ type Database struct {
|
|||
Reaction *ReactionQuery
|
||||
|
||||
DisappearingMessage *DisappearingMessageQuery
|
||||
Backfill *BackfillQuery
|
||||
BackfillQueue *BackfillTaskQuery
|
||||
BackfillState *BackfillStateQuery
|
||||
HistorySync *HistorySyncQuery
|
||||
MediaBackfillRequest *MediaBackfillRequestQuery
|
||||
}
|
||||
|
||||
func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
|
||||
db := &Database{Database: baseDB}
|
||||
func New(db *dbutil.Database) *Database {
|
||||
db.UpgradeTable = upgrades.Table
|
||||
db.User = &UserQuery{
|
||||
db: db,
|
||||
log: log.Sub("User"),
|
||||
return &Database{
|
||||
Database: db,
|
||||
User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
|
||||
Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
|
||||
Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
|
||||
Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
|
||||
Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
|
||||
|
||||
DisappearingMessage: &DisappearingMessageQuery{dbutil.MakeQueryHelper(db, newDisappearingMessage)},
|
||||
BackfillQueue: &BackfillTaskQuery{dbutil.MakeQueryHelper(db, newBackfillTask)},
|
||||
BackfillState: &BackfillStateQuery{dbutil.MakeQueryHelper(db, newBackfillState)},
|
||||
HistorySync: &HistorySyncQuery{dbutil.MakeQueryHelper(db, newHistorySyncConversation)},
|
||||
MediaBackfillRequest: &MediaBackfillRequestQuery{dbutil.MakeQueryHelper(db, newMediaBackfillRequest)},
|
||||
}
|
||||
db.Portal = &PortalQuery{
|
||||
db: db,
|
||||
log: log.Sub("Portal"),
|
||||
}
|
||||
db.Puppet = &PuppetQuery{
|
||||
db: db,
|
||||
log: log.Sub("Puppet"),
|
||||
}
|
||||
db.Message = &MessageQuery{
|
||||
db: db,
|
||||
log: log.Sub("Message"),
|
||||
}
|
||||
db.Reaction = &ReactionQuery{
|
||||
db: db,
|
||||
log: log.Sub("Reaction"),
|
||||
}
|
||||
db.DisappearingMessage = &DisappearingMessageQuery{
|
||||
db: db,
|
||||
log: log.Sub("DisappearingMessage"),
|
||||
}
|
||||
db.Backfill = &BackfillQuery{
|
||||
db: db,
|
||||
log: log.Sub("Backfill"),
|
||||
}
|
||||
db.HistorySync = &HistorySyncQuery{
|
||||
db: db,
|
||||
log: log.Sub("HistorySync"),
|
||||
}
|
||||
db.MediaBackfillRequest = &MediaBackfillRequestQuery{
|
||||
db: db,
|
||||
log: log.Sub("MediaBackfillRequest"),
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func isRetryableError(err error) bool {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,32 +17,29 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
type DisappearingMessageQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*DisappearingMessage]
|
||||
}
|
||||
|
||||
func (dmq *DisappearingMessageQuery) New() *DisappearingMessage {
|
||||
func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage {
|
||||
return &DisappearingMessage{
|
||||
db: dmq.db,
|
||||
log: dmq.log,
|
||||
qh: qh,
|
||||
}
|
||||
}
|
||||
|
||||
func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage {
|
||||
dm := &DisappearingMessage{
|
||||
db: dmq.db,
|
||||
log: dmq.log,
|
||||
qh: dmq.QueryHelper,
|
||||
|
||||
RoomID: roomID,
|
||||
EventID: eventID,
|
||||
ExpireIn: expireIn,
|
||||
|
@ -55,22 +52,17 @@ const (
|
|||
getAllScheduledDisappearingMessagesQuery = `
|
||||
SELECT room_id, event_id, expire_in, expire_at FROM disappearing_message WHERE expire_at IS NOT NULL AND expire_at <= $1
|
||||
`
|
||||
insertDisappearingMessageQuery = `INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`
|
||||
updateDisappearingMessageExpiryQuery = "UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3"
|
||||
deleteDisappearingMessageQuery = "DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2"
|
||||
)
|
||||
|
||||
func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(duration time.Duration) (messages []*DisappearingMessage) {
|
||||
rows, err := dmq.db.Query(getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli())
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
for rows.Next() {
|
||||
messages = append(messages, dmq.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) {
|
||||
return dmq.QueryMany(ctx, getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli())
|
||||
}
|
||||
|
||||
type DisappearingMessage struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
qh *dbutil.QueryHelper[*DisappearingMessage]
|
||||
|
||||
RoomID id.RoomID
|
||||
EventID id.EventID
|
||||
|
@ -78,50 +70,33 @@ type DisappearingMessage struct {
|
|||
ExpireAt time.Time
|
||||
}
|
||||
|
||||
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
|
||||
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
|
||||
var expireIn int64
|
||||
var expireAt sql.NullInt64
|
||||
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
msg.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
msg.ExpireIn = time.Duration(expireIn) * time.Millisecond
|
||||
if expireAt.Valid {
|
||||
msg.ExpireAt = time.UnixMilli(expireAt.Int64)
|
||||
}
|
||||
return msg
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (msg *DisappearingMessage) Insert(txn dbutil.Execable) {
|
||||
if txn == nil {
|
||||
txn = msg.db
|
||||
}
|
||||
var expireAt sql.NullInt64
|
||||
if !msg.ExpireAt.IsZero() {
|
||||
expireAt.Valid = true
|
||||
expireAt.Int64 = msg.ExpireAt.UnixMilli()
|
||||
}
|
||||
_, err := txn.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`,
|
||||
msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), expireAt)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err)
|
||||
}
|
||||
func (msg *DisappearingMessage) sqlVariables() []any {
|
||||
return []any{msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), dbutil.UnixMilliPtr(msg.ExpireAt)}
|
||||
}
|
||||
|
||||
func (msg *DisappearingMessage) StartTimer() {
|
||||
func (msg *DisappearingMessage) Insert(ctx context.Context) error {
|
||||
return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (msg *DisappearingMessage) StartTimer(ctx context.Context) error {
|
||||
msg.ExpireAt = time.Now().Add(msg.ExpireIn * time.Second)
|
||||
_, err := msg.db.Exec("UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3", msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to update %s/%s: %v", msg.RoomID, msg.EventID, err)
|
||||
}
|
||||
return msg.qh.Exec(ctx, updateDisappearingMessageExpiryQuery, msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
|
||||
}
|
||||
|
||||
func (msg *DisappearingMessage) Delete() {
|
||||
_, err := msg.db.Exec("DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2", msg.RoomID, msg.EventID)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to delete %s/%s: %v", msg.RoomID, msg.EventID, err)
|
||||
}
|
||||
func (msg *DisappearingMessage) Delete(ctx context.Context) error {
|
||||
return msg.qh.Exec(ctx, deleteDisappearingMessageQuery, msg.RoomID, msg.EventID)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan, Sumner Evans
|
||||
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,8 +17,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -26,23 +25,19 @@ import (
|
|||
"go.mau.fi/util/dbutil"
|
||||
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||||
"google.golang.org/protobuf/proto"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type HistorySyncQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*HistorySyncConversation]
|
||||
}
|
||||
|
||||
type HistorySyncConversation struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
qh *dbutil.QueryHelper[*HistorySyncConversation]
|
||||
|
||||
UserID id.UserID
|
||||
ConversationID string
|
||||
PortalKey *PortalKey
|
||||
PortalKey PortalKey
|
||||
LastMessageTimestamp time.Time
|
||||
MuteEndTime time.Time
|
||||
Archived bool
|
||||
|
@ -54,18 +49,16 @@ type HistorySyncConversation struct {
|
|||
UnreadCount uint32
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation {
|
||||
func newHistorySyncConversation(qh *dbutil.QueryHelper[*HistorySyncConversation]) *HistorySyncConversation {
|
||||
return &HistorySyncConversation{
|
||||
db: hsq.db,
|
||||
log: hsq.log,
|
||||
PortalKey: &PortalKey{},
|
||||
qh: qh,
|
||||
}
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) NewConversationWithValues(
|
||||
userID id.UserID,
|
||||
conversationID string,
|
||||
portalKey *PortalKey,
|
||||
portalKey PortalKey,
|
||||
lastMessageTimestamp,
|
||||
muteEndTime uint64,
|
||||
archived bool,
|
||||
|
@ -74,10 +67,10 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
|
|||
endOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType,
|
||||
ephemeralExpiration *uint32,
|
||||
markedAsUnread bool,
|
||||
unreadCount uint32) *HistorySyncConversation {
|
||||
unreadCount uint32,
|
||||
) *HistorySyncConversation {
|
||||
return &HistorySyncConversation{
|
||||
db: hsq.db,
|
||||
log: hsq.log,
|
||||
qh: hsq.QueryHelper,
|
||||
UserID: userID,
|
||||
ConversationID: conversationID,
|
||||
PortalKey: portalKey,
|
||||
|
@ -94,6 +87,17 @@ func (hsq *HistorySyncQuery) NewConversationWithValues(
|
|||
}
|
||||
|
||||
const (
|
||||
upsertHistorySyncConversationQuery = `
|
||||
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
ON CONFLICT (user_mxid, conversation_id)
|
||||
DO UPDATE SET
|
||||
last_message_timestamp=CASE
|
||||
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
|
||||
ELSE history_sync_conversation.last_message_timestamp
|
||||
END,
|
||||
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
|
||||
`
|
||||
getNMostRecentConversations = `
|
||||
SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count
|
||||
FROM history_sync_conversation
|
||||
|
@ -108,24 +112,19 @@ const (
|
|||
AND portal_jid=$2
|
||||
AND portal_receiver=$3
|
||||
`
|
||||
deleteAllConversationsQuery = "DELETE FROM history_sync_conversation WHERE user_mxid=$1"
|
||||
deleteHistorySyncConversationQuery = `
|
||||
DELETE FROM history_sync_conversation
|
||||
WHERE user_mxid=$1 AND conversation_id=$2
|
||||
`
|
||||
)
|
||||
|
||||
func (hsc *HistorySyncConversation) Upsert() {
|
||||
_, err := hsc.db.Exec(`
|
||||
INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
ON CONFLICT (user_mxid, conversation_id)
|
||||
DO UPDATE SET
|
||||
last_message_timestamp=CASE
|
||||
WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp
|
||||
ELSE history_sync_conversation.last_message_timestamp
|
||||
END,
|
||||
end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type
|
||||
`,
|
||||
func (hsc *HistorySyncConversation) sqlVariables() []any {
|
||||
return []any{
|
||||
hsc.UserID,
|
||||
hsc.ConversationID,
|
||||
hsc.PortalKey.JID.String(),
|
||||
hsc.PortalKey.Receiver.String(),
|
||||
hsc.PortalKey.JID,
|
||||
hsc.PortalKey.Receiver,
|
||||
hsc.LastMessageTimestamp,
|
||||
hsc.Archived,
|
||||
hsc.Pinned,
|
||||
|
@ -134,14 +133,16 @@ func (hsc *HistorySyncConversation) Upsert() {
|
|||
hsc.EndOfHistoryTransferType,
|
||||
hsc.EphemeralExpiration,
|
||||
hsc.MarkedAsUnread,
|
||||
hsc.UnreadCount)
|
||||
if err != nil {
|
||||
hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err)
|
||||
hsc.UnreadCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
|
||||
err := row.Scan(
|
||||
func (hsc *HistorySyncConversation) Upsert(ctx context.Context) error {
|
||||
return hsc.qh.Exec(ctx, upsertHistorySyncConversationQuery, hsc.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) (*HistorySyncConversation, error) {
|
||||
return dbutil.ValueOrErr(hsc, row.Scan(
|
||||
&hsc.UserID,
|
||||
&hsc.ConversationID,
|
||||
&hsc.PortalKey.JID,
|
||||
|
@ -154,69 +155,59 @@ func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConve
|
|||
&hsc.EndOfHistoryTransferType,
|
||||
&hsc.EphemeralExpiration,
|
||||
&hsc.MarkedAsUnread,
|
||||
&hsc.UnreadCount)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
hsc.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return hsc
|
||||
&hsc.UnreadCount,
|
||||
))
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) GetRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
|
||||
func (hsq *HistorySyncQuery) GetRecentConversations(ctx context.Context, userID id.UserID, n int) ([]*HistorySyncConversation, error) {
|
||||
nPtr := &n
|
||||
// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
|
||||
if n < 0 && hsq.db.Dialect == dbutil.Postgres {
|
||||
if n < 0 && hsq.GetDB().Dialect == dbutil.Postgres {
|
||||
nPtr = nil
|
||||
}
|
||||
rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)
|
||||
defer rows.Close()
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
for rows.Next() {
|
||||
conversations = append(conversations, hsq.NewConversation().Scan(rows))
|
||||
}
|
||||
return
|
||||
return hsq.QueryMany(ctx, getNMostRecentConversations, userID, nPtr)
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey PortalKey) (conversation *HistorySyncConversation) {
|
||||
rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
|
||||
defer rows.Close()
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
if rows.Next() {
|
||||
conversation = hsq.NewConversation().Scan(rows)
|
||||
}
|
||||
return
|
||||
func (hsq *HistorySyncQuery) GetConversation(ctx context.Context, userID id.UserID, portalKey PortalKey) (*HistorySyncConversation, error) {
|
||||
return hsq.QueryOne(ctx, getConversationByPortal, userID, portalKey.JID, portalKey.Receiver)
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) {
|
||||
_, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID)
|
||||
if err != nil {
|
||||
hsq.log.Warnfln("Failed to delete historical chat info for %s/%s: %v", userID, err)
|
||||
}
|
||||
func (hsq *HistorySyncQuery) DeleteAllConversations(ctx context.Context, userID id.UserID) error {
|
||||
return hsq.Exec(ctx, deleteAllConversationsQuery, userID)
|
||||
}
|
||||
|
||||
const (
|
||||
getMessagesBetween = `
|
||||
insertHistorySyncMessageQuery = `
|
||||
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
|
||||
`
|
||||
getHistorySyncMessagesBetweenQueryTemplate = `
|
||||
SELECT data FROM history_sync_message
|
||||
WHERE user_mxid=$1 AND conversation_id=$2
|
||||
%s
|
||||
ORDER BY timestamp DESC
|
||||
%s
|
||||
`
|
||||
deleteMessagesBetweenExclusive = `
|
||||
deleteHistorySyncMessagesBetweenExclusiveQuery = `
|
||||
DELETE FROM history_sync_message
|
||||
WHERE user_mxid=$1 AND conversation_id=$2 AND timestamp<$3 AND timestamp>$4
|
||||
`
|
||||
deleteAllHistorySyncMessagesQuery = "DELETE FROM history_sync_message WHERE user_mxid=$1"
|
||||
deleteHistorySyncMessagesForPortalQuery = `
|
||||
DELETE FROM history_sync_message
|
||||
WHERE user_mxid=$1 AND conversation_id=$2
|
||||
`
|
||||
conversationHasHistorySyncMessagesQuery = `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM history_sync_message
|
||||
WHERE user_mxid=$1 AND conversation_id=$2
|
||||
)
|
||||
`
|
||||
)
|
||||
|
||||
type HistorySyncMessage struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
hsq *HistorySyncQuery
|
||||
|
||||
UserID id.UserID
|
||||
ConversationID string
|
||||
|
@ -231,8 +222,8 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
|
|||
return nil, err
|
||||
}
|
||||
return &HistorySyncMessage{
|
||||
db: hsq.db,
|
||||
log: hsq.log,
|
||||
hsq: hsq,
|
||||
|
||||
UserID: userID,
|
||||
ConversationID: conversationID,
|
||||
MessageID: messageID,
|
||||
|
@ -241,18 +232,27 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (hsm *HistorySyncMessage) Insert() error {
|
||||
_, err := hsm.db.Exec(`
|
||||
INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING
|
||||
`, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
|
||||
return err
|
||||
func (hsm *HistorySyncMessage) Insert(ctx context.Context) error {
|
||||
return hsm.hsq.Exec(ctx, insertHistorySyncMessageQuery, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now())
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) {
|
||||
func scanWebMessageInfo(rows dbutil.Scannable) (*waProto.WebMessageInfo, error) {
|
||||
var msgData []byte
|
||||
err := rows.Scan(&msgData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var historySyncMsg waProto.HistorySyncMsg
|
||||
err = proto.Unmarshal(msgData, &historySyncMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal message: %w", err)
|
||||
}
|
||||
return historySyncMsg.GetMessage(), nil
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) GetMessagesBetween(ctx context.Context, userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) ([]*waProto.WebMessageInfo, error) {
|
||||
whereClauses := ""
|
||||
args := []interface{}{userID, conversationID}
|
||||
args := []any{userID, conversationID}
|
||||
argNum := 3
|
||||
if startTime != nil {
|
||||
whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum)
|
||||
|
@ -268,80 +268,35 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID
|
|||
if limit > 0 {
|
||||
limitClause = fmt.Sprintf("LIMIT %d", limit)
|
||||
}
|
||||
query := fmt.Sprintf(getHistorySyncMessagesBetweenQueryTemplate, whereClauses, limitClause)
|
||||
|
||||
rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...)
|
||||
defer rows.Close()
|
||||
if err != nil || rows == nil {
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
hsq.log.Warnfln("Failed to query messages between range: %v", err)
|
||||
}
|
||||
return nil
|
||||
return dbutil.ConvertRowFn[*waProto.WebMessageInfo](scanWebMessageInfo).
|
||||
NewRowIter(hsq.GetDB().Query(ctx, query, args...)).
|
||||
AsList()
|
||||
}
|
||||
|
||||
var msgData []byte
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&msgData)
|
||||
if err != nil {
|
||||
hsq.log.Errorfln("Database scan failed: %v", err)
|
||||
continue
|
||||
}
|
||||
var historySyncMsg waProto.HistorySyncMsg
|
||||
err = proto.Unmarshal(msgData, &historySyncMsg)
|
||||
if err != nil {
|
||||
hsq.log.Errorfln("Failed to unmarshal history sync message: %v", err)
|
||||
continue
|
||||
}
|
||||
messages = append(messages, historySyncMsg.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
|
||||
func (hsq *HistorySyncQuery) DeleteMessages(ctx context.Context, userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error {
|
||||
newest := messages[0]
|
||||
beforeTS := time.Unix(int64(newest.GetMessageTimestamp())+1, 0)
|
||||
oldest := messages[len(messages)-1]
|
||||
afterTS := time.Unix(int64(oldest.GetMessageTimestamp())-1, 0)
|
||||
_, err := hsq.db.Exec(deleteMessagesBetweenExclusive, userID, conversationID, beforeTS, afterTS)
|
||||
return err
|
||||
return hsq.Exec(ctx, deleteHistorySyncMessagesBetweenExclusiveQuery, userID, conversationID, beforeTS, afterTS)
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) {
|
||||
_, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID)
|
||||
if err != nil {
|
||||
hsq.log.Warnfln("Failed to delete historical messages for %s: %v", userID, err)
|
||||
}
|
||||
func (hsq *HistorySyncQuery) DeleteAllMessages(ctx context.Context, userID id.UserID) error {
|
||||
return hsq.Exec(ctx, deleteAllHistorySyncMessagesQuery, userID)
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(userID id.UserID, portalKey PortalKey) {
|
||||
_, err := hsq.db.Exec(`
|
||||
DELETE FROM history_sync_message
|
||||
WHERE user_mxid=$1 AND conversation_id=$2
|
||||
`, userID, portalKey.JID)
|
||||
if err != nil {
|
||||
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, portalKey.JID, err)
|
||||
}
|
||||
func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error {
|
||||
return hsq.Exec(ctx, deleteHistorySyncMessagesForPortalQuery, userID, portalKey.JID)
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) ConversationHasMessages(userID id.UserID, portalKey PortalKey) (exists bool) {
|
||||
err := hsq.db.QueryRow(`
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM history_sync_message
|
||||
WHERE user_mxid=$1 AND conversation_id=$2
|
||||
)
|
||||
`, userID, portalKey.JID).Scan(&exists)
|
||||
if err != nil {
|
||||
hsq.log.Warnfln("Failed to check if any messages are stored for %s/%s: %v", userID, portalKey.JID, err)
|
||||
}
|
||||
func (hsq *HistorySyncQuery) ConversationHasMessages(ctx context.Context, userID id.UserID, portalKey PortalKey) (exists bool, err error) {
|
||||
err = hsq.GetDB().QueryRow(ctx, conversationHasHistorySyncMessagesQuery, userID, portalKey.JID).Scan(&exists)
|
||||
return
|
||||
}
|
||||
|
||||
func (hsq *HistorySyncQuery) DeleteConversation(userID id.UserID, jid string) {
|
||||
func (hsq *HistorySyncQuery) DeleteConversation(ctx context.Context, userID id.UserID, jid string) error {
|
||||
// This will also clear history_sync_message as there's a foreign key constraint
|
||||
_, err := hsq.db.Exec(`
|
||||
DELETE FROM history_sync_conversation
|
||||
WHERE user_mxid=$1 AND conversation_id=$2
|
||||
`, userID, jid)
|
||||
if err != nil {
|
||||
hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, jid, err)
|
||||
}
|
||||
return hsq.Exec(ctx, deleteHistorySyncConversationQuery, userID, jid)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan, Sumner Evans
|
||||
// Copyright (C) 2024 Tulir Asokan, Sumner Evans
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,14 +17,12 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"context"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.mau.fi/util/dbutil"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
type MediaBackfillRequestStatus int
|
||||
|
@ -36,34 +34,46 @@ const (
|
|||
)
|
||||
|
||||
type MediaBackfillRequestQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*MediaBackfillRequest]
|
||||
}
|
||||
|
||||
type MediaBackfillRequest struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
const (
|
||||
getAllMediaBackfillRequestsForUserQuery = `
|
||||
SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
|
||||
FROM media_backfill_requests
|
||||
WHERE user_mxid=$1
|
||||
AND status=0
|
||||
`
|
||||
deleteAllMediaBackfillRequestsForUserQuery = "DELETE FROM media_backfill_requests WHERE user_mxid=$1"
|
||||
upsertMediaBackfillRequestQuery = `
|
||||
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id)
|
||||
DO UPDATE SET
|
||||
media_key=excluded.media_key,
|
||||
status=excluded.status,
|
||||
error=excluded.error
|
||||
`
|
||||
)
|
||||
|
||||
UserID id.UserID
|
||||
PortalKey *PortalKey
|
||||
EventID id.EventID
|
||||
MediaKey []byte
|
||||
Status MediaBackfillRequestStatus
|
||||
Error string
|
||||
func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(ctx context.Context, userID id.UserID) ([]*MediaBackfillRequest, error) {
|
||||
return mbrq.QueryMany(ctx, getAllMediaBackfillRequestsForUserQuery, userID)
|
||||
}
|
||||
|
||||
func (mbrq *MediaBackfillRequestQuery) newMediaBackfillRequest() *MediaBackfillRequest {
|
||||
func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(ctx context.Context, userID id.UserID) error {
|
||||
return mbrq.Exec(ctx, deleteAllMediaBackfillRequestsForUserQuery, userID)
|
||||
}
|
||||
|
||||
func newMediaBackfillRequest(qh *dbutil.QueryHelper[*MediaBackfillRequest]) *MediaBackfillRequest {
|
||||
return &MediaBackfillRequest{
|
||||
db: mbrq.db,
|
||||
log: mbrq.log,
|
||||
PortalKey: &PortalKey{},
|
||||
qh: qh,
|
||||
}
|
||||
}
|
||||
|
||||
func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest {
|
||||
func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest {
|
||||
return &MediaBackfillRequest{
|
||||
db: mbrq.db,
|
||||
log: mbrq.log,
|
||||
qh: mbrq.QueryHelper,
|
||||
|
||||
UserID: userID,
|
||||
PortalKey: portalKey,
|
||||
EventID: eventID,
|
||||
|
@ -72,62 +82,25 @@ func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID
|
|||
}
|
||||
}
|
||||
|
||||
const (
|
||||
getMediaBackfillRequestsForUser = `
|
||||
SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error
|
||||
FROM media_backfill_requests
|
||||
WHERE user_mxid=$1
|
||||
AND status=0
|
||||
`
|
||||
)
|
||||
type MediaBackfillRequest struct {
|
||||
qh *dbutil.QueryHelper[*MediaBackfillRequest]
|
||||
|
||||
func (mbr *MediaBackfillRequest) Upsert() {
|
||||
_, err := mbr.db.Exec(`
|
||||
INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id)
|
||||
DO UPDATE SET
|
||||
media_key=EXCLUDED.media_key,
|
||||
status=EXCLUDED.status,
|
||||
error=EXCLUDED.error`,
|
||||
mbr.UserID,
|
||||
mbr.PortalKey.JID.String(),
|
||||
mbr.PortalKey.Receiver.String(),
|
||||
mbr.EventID,
|
||||
mbr.MediaKey,
|
||||
mbr.Status,
|
||||
mbr.Error)
|
||||
if err != nil {
|
||||
mbr.log.Warnfln("Failed to insert media backfill request %s/%s/%s: %v", mbr.UserID, mbr.PortalKey.String(), mbr.EventID, err)
|
||||
}
|
||||
UserID id.UserID
|
||||
PortalKey PortalKey
|
||||
EventID id.EventID
|
||||
MediaKey []byte
|
||||
Status MediaBackfillRequestStatus
|
||||
Error string
|
||||
}
|
||||
|
||||
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest {
|
||||
err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
mbr.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return mbr
|
||||
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) (*MediaBackfillRequest, error) {
|
||||
return dbutil.ValueOrErr(mbr, row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error))
|
||||
}
|
||||
|
||||
func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(userID id.UserID) (requests []*MediaBackfillRequest) {
|
||||
rows, err := mbrq.db.Query(getMediaBackfillRequestsForUser, userID)
|
||||
defer rows.Close()
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
for rows.Next() {
|
||||
requests = append(requests, mbrq.newMediaBackfillRequest().Scan(rows))
|
||||
}
|
||||
return
|
||||
func (mbr *MediaBackfillRequest) sqlVariables() []any {
|
||||
return []any{mbr.UserID, mbr.PortalKey.JID, mbr.PortalKey.Receiver, mbr.EventID, mbr.MediaKey, mbr.Status, mbr.Error}
|
||||
}
|
||||
|
||||
func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(userID id.UserID) {
|
||||
_, err := mbrq.db.Exec("DELETE FROM media_backfill_requests WHERE user_mxid=$1", userID)
|
||||
if err != nil {
|
||||
mbrq.log.Warnfln("Failed to delete media backfill requests for %s: %v", userID, err)
|
||||
}
|
||||
func (mbr *MediaBackfillRequest) Upsert(ctx context.Context) error {
|
||||
return mbr.qh.Exec(ctx, upsertMediaBackfillRequestQuery, mbr.sqlVariables()...)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2021 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,29 +17,22 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type MessageQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*Message]
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) New() *Message {
|
||||
return &Message{
|
||||
db: mq.db,
|
||||
log: mq.log,
|
||||
}
|
||||
func newMessage(qh *dbutil.QueryHelper[*Message]) *Message {
|
||||
return &Message{qh: qh}
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -67,60 +60,47 @@ const (
|
|||
SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message
|
||||
WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC
|
||||
`
|
||||
insertMessageQuery = `
|
||||
INSERT INTO message
|
||||
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
`
|
||||
markMessageSentQuery = "UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4"
|
||||
updateMessageMXIDQuery = "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6"
|
||||
deleteMessageQuery = "DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3"
|
||||
)
|
||||
|
||||
func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
|
||||
rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver)
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
for rows.Next() {
|
||||
messages = append(messages, mq.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
func (mq *MessageQuery) GetAll(ctx context.Context, chat PortalKey) ([]*Message, error) {
|
||||
return mq.QueryMany(ctx, getAllMessagesQuery, chat.JID, chat.Receiver)
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message {
|
||||
return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid))
|
||||
func (mq *MessageQuery) GetByJID(ctx context.Context, chat PortalKey, jid types.MessageID) (*Message, error) {
|
||||
return mq.QueryOne(ctx, getMessageByJIDQuery, chat.JID, chat.Receiver, jid)
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
|
||||
return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid))
|
||||
func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) {
|
||||
return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid)
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
|
||||
return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second))
|
||||
func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat PortalKey) (*Message, error) {
|
||||
return mq.GetLastInChatBefore(ctx, chat, time.Now().Add(60*time.Second))
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message {
|
||||
msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()))
|
||||
if msg == nil || msg.Timestamp.IsZero() {
|
||||
func (mq *MessageQuery) GetLastInChatBefore(ctx context.Context, chat PortalKey, maxTimestamp time.Time) (*Message, error) {
|
||||
msg, err := mq.QueryOne(ctx, getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix())
|
||||
if msg != nil && msg.Timestamp.IsZero() {
|
||||
// Old db, we don't know what the last message is.
|
||||
return nil
|
||||
msg = nil
|
||||
}
|
||||
return msg
|
||||
return msg, err
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message {
|
||||
return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver))
|
||||
func (mq *MessageQuery) GetFirstInChat(ctx context.Context, chat PortalKey) (*Message, error) {
|
||||
return mq.QueryOne(ctx, getFirstMessageInChatQuery, chat.JID, chat.Receiver)
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) {
|
||||
rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
for rows.Next() {
|
||||
messages = append(messages, mq.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) maybeScan(row *sql.Row) *Message {
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
return mq.New().Scan(row)
|
||||
func (mq *MessageQuery) GetMessagesBetween(ctx context.Context, chat PortalKey, minTimestamp, maxTimestamp time.Time) ([]*Message, error) {
|
||||
return mq.QueryMany(ctx, getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix())
|
||||
}
|
||||
|
||||
type MessageErrorType string
|
||||
|
@ -144,8 +124,7 @@ const (
|
|||
)
|
||||
|
||||
type Message struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
qh *dbutil.QueryHelper[*Message]
|
||||
|
||||
Chat PortalKey
|
||||
JID types.MessageID
|
||||
|
@ -172,76 +151,49 @@ func (msg *Message) IsFakeJID() bool {
|
|||
|
||||
const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s"
|
||||
|
||||
func (msg *Message) Scan(row dbutil.Scannable) *Message {
|
||||
func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
|
||||
var ts int64
|
||||
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
msg.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") {
|
||||
_, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID)
|
||||
if err != nil {
|
||||
msg.log.Errorln("Parsing gallery MXID failed:", err)
|
||||
return nil, fmt.Errorf("failed to parse gallery MXID: %w", err)
|
||||
}
|
||||
}
|
||||
if ts != 0 {
|
||||
msg.Timestamp = time.Unix(ts, 0)
|
||||
}
|
||||
return msg
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (msg *Message) Insert(txn dbutil.Execable) {
|
||||
if txn == nil {
|
||||
txn = msg.db
|
||||
}
|
||||
var sender interface{} = msg.Sender
|
||||
// Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events)
|
||||
if msg.Sender.IsEmpty() {
|
||||
sender = ""
|
||||
}
|
||||
func (msg *Message) sqlVariables() []any {
|
||||
mxid := msg.MXID.String()
|
||||
if msg.GalleryPart != 0 {
|
||||
mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid)
|
||||
}
|
||||
_, err := txn.Exec(`
|
||||
INSERT INTO message
|
||||
(chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
`, msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
|
||||
}
|
||||
return []any{msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, msg.Sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID}
|
||||
}
|
||||
|
||||
func (msg *Message) MarkSent(ts time.Time) {
|
||||
func (msg *Message) Insert(ctx context.Context) error {
|
||||
return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (msg *Message) MarkSent(ctx context.Context, ts time.Time) error {
|
||||
msg.Sent = true
|
||||
msg.Timestamp = ts
|
||||
_, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
|
||||
}
|
||||
return msg.qh.Exec(ctx, markMessageSentQuery, ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID)
|
||||
}
|
||||
|
||||
func (msg *Message) UpdateMXID(txn dbutil.Execable, mxid id.EventID, newType MessageType, newError MessageErrorType) {
|
||||
if txn == nil {
|
||||
txn = msg.db
|
||||
}
|
||||
func (msg *Message) UpdateMXID(ctx context.Context, mxid id.EventID, newType MessageType, newError MessageErrorType) error {
|
||||
msg.MXID = mxid
|
||||
msg.Type = newType
|
||||
msg.Error = newError
|
||||
_, err := txn.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6",
|
||||
mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err)
|
||||
}
|
||||
return msg.qh.Exec(ctx, updateMessageMXIDQuery, mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
|
||||
}
|
||||
|
||||
func (msg *Message) Delete() {
|
||||
_, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID)
|
||||
if err != nil {
|
||||
msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err)
|
||||
}
|
||||
func (msg *Message) Delete(ctx context.Context) error {
|
||||
return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,28 +17,56 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) {
|
||||
var hash []byte
|
||||
err = rows.Scan(&id, &hash)
|
||||
if err != nil {
|
||||
// return below
|
||||
} else if len(hash) != 32 {
|
||||
err = fmt.Errorf("unexpected hash length %d", len(hash))
|
||||
} else {
|
||||
hashArr = *(*[32]byte)(hash)
|
||||
const (
|
||||
bulkPutPollOptionsQuery = "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
|
||||
bulkPutPollOptionsQueryTemplate = "($1, $%d, $%d)"
|
||||
bulkPutPollOptionsQueryPlaceholder = "($1, $2, $3)"
|
||||
getPollOptionIDsByHashesQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
|
||||
getPollOptionHashesByIDsQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
|
||||
getPollOptionQuerySQLiteArrayTemplate = " IN (%s)"
|
||||
getPollOptionQueryArrayPlaceholder = " = ANY($2)"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, "meow") == bulkPutPollOptionsQuery {
|
||||
panic("Bulk insert query placeholder not found")
|
||||
}
|
||||
if strings.ReplaceAll(getPollOptionIDsByHashesQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
|
||||
panic("Array select query placeholder not found")
|
||||
}
|
||||
if strings.ReplaceAll(getPollOptionHashesByIDsQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery {
|
||||
panic("Array select query placeholder not found")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
|
||||
query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)"
|
||||
type pollOption struct {
|
||||
id string
|
||||
hash [32]byte
|
||||
}
|
||||
|
||||
func scanPollOption(rows dbutil.Scannable) (*pollOption, error) {
|
||||
var hash []byte
|
||||
var id string
|
||||
err := rows.Scan(&id, &hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(hash) != 32 {
|
||||
return nil, fmt.Errorf("unexpected hash length %d", len(hash))
|
||||
} else {
|
||||
return &pollOption{id: id, hash: [32]byte(hash)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (msg *Message) PutPollOptions(ctx context.Context, opts map[[32]byte]string) error {
|
||||
args := make([]any, len(opts)*2+1)
|
||||
placeholders := make([]string, len(opts))
|
||||
args[0] = msg.MXID
|
||||
|
@ -47,72 +75,47 @@ func (msg *Message) PutPollOptions(opts map[[32]byte]string) {
|
|||
args[i*2+1] = id
|
||||
hashCopy := hash
|
||||
args[i*2+2] = hashCopy[:]
|
||||
placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3)
|
||||
placeholders[i] = fmt.Sprintf(bulkPutPollOptionsQueryTemplate, i*2+2, i*2+3)
|
||||
i++
|
||||
}
|
||||
query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ","))
|
||||
_, err := msg.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err)
|
||||
}
|
||||
query := strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, strings.Join(placeholders, ","))
|
||||
return msg.qh.Exec(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string {
|
||||
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)"
|
||||
func getPollOptions[LookupKey any, Key comparable, Value any](
|
||||
ctx context.Context,
|
||||
msg *Message,
|
||||
query string,
|
||||
things []LookupKey,
|
||||
getKeyValue func(option *pollOption) (Key, Value),
|
||||
) (map[Key]Value, error) {
|
||||
var args []any
|
||||
if msg.db.Dialect == dbutil.Postgres {
|
||||
args = []any{msg.MXID, pq.Array(hashes)}
|
||||
if msg.qh.GetDB().Dialect == dbutil.Postgres {
|
||||
args = []any{msg.MXID, pq.Array(things)}
|
||||
} else {
|
||||
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ",")))
|
||||
args = make([]any, len(hashes)+1)
|
||||
query = strings.ReplaceAll(query, getPollOptionQueryArrayPlaceholder, fmt.Sprintf(getPollOptionQuerySQLiteArrayTemplate, strings.TrimSuffix(strings.Repeat("?,", len(things)), ",")))
|
||||
args = make([]any, len(things)+1)
|
||||
args[0] = msg.MXID
|
||||
for i, hash := range hashes {
|
||||
args[i+1] = hash
|
||||
for i, thing := range things {
|
||||
args[i+1] = thing
|
||||
}
|
||||
}
|
||||
ids := make(map[[32]byte]string, len(hashes))
|
||||
rows, err := msg.db.Query(query, args...)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err)
|
||||
} else {
|
||||
for rows.Next() {
|
||||
id, hash, err := scanPollOptionMapping(rows)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err)
|
||||
break
|
||||
}
|
||||
ids[hash] = id
|
||||
}
|
||||
}
|
||||
return ids
|
||||
return dbutil.RowIterAsMap(
|
||||
dbutil.ConvertRowFn[*pollOption](scanPollOption).NewRowIter(msg.qh.GetDB().Query(ctx, query, args...)),
|
||||
getKeyValue,
|
||||
)
|
||||
}
|
||||
|
||||
func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte {
|
||||
query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)"
|
||||
var args []any
|
||||
if msg.db.Dialect == dbutil.Postgres {
|
||||
args = []any{msg.MXID, pq.Array(ids)}
|
||||
} else {
|
||||
query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")))
|
||||
args = make([]any, len(ids)+1)
|
||||
args[0] = msg.MXID
|
||||
for i, id := range ids {
|
||||
args[i+1] = id
|
||||
func (msg *Message) GetPollOptionIDs(ctx context.Context, hashes [][]byte) (map[[32]byte]string, error) {
|
||||
return getPollOptions(
|
||||
ctx, msg, getPollOptionIDsByHashesQuery, hashes,
|
||||
func(t *pollOption) ([32]byte, string) { return t.hash, t.id },
|
||||
)
|
||||
}
|
||||
}
|
||||
hashes := make(map[string][32]byte, len(ids))
|
||||
rows, err := msg.db.Query(query, args...)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err)
|
||||
} else {
|
||||
for rows.Next() {
|
||||
id, hash, err := scanPollOptionMapping(rows)
|
||||
if err != nil {
|
||||
msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err)
|
||||
break
|
||||
}
|
||||
hashes[id] = hash
|
||||
}
|
||||
}
|
||||
return hashes
|
||||
|
||||
func (msg *Message) GetPollOptionHashes(ctx context.Context, ids []string) (map[string][32]byte, error) {
|
||||
return getPollOptions(
|
||||
ctx, msg, getPollOptionHashesByIDsQuery, ids,
|
||||
func(t *pollOption) (string, [32]byte) { return t.id, t.hash },
|
||||
)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2021 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,15 +17,14 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
type PortalKey struct {
|
||||
|
@ -53,90 +52,89 @@ func (key PortalKey) String() string {
|
|||
}
|
||||
|
||||
type PortalQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*Portal]
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) New() *Portal {
|
||||
func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
|
||||
return &Portal{
|
||||
db: pq.db,
|
||||
log: pq.log,
|
||||
qh: qh,
|
||||
}
|
||||
}
|
||||
|
||||
const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, relay_user_id, expiration_time"
|
||||
|
||||
func (pq *PortalQuery) GetAll() []*Portal {
|
||||
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns))
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
|
||||
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
|
||||
return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
|
||||
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD())
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAllByParentGroup(jid types.JID) []*Portal {
|
||||
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE parent_group=$1", portalColumns), jid)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
|
||||
return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns), receiver.ToNonAD())
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
|
||||
receiver = receiver.ToNonAD()
|
||||
rows, err := pq.db.Query(`
|
||||
const (
|
||||
getAllPortalsQuery = `
|
||||
SELECT jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
|
||||
encrypted, last_sync, is_parent, parent_group, in_space,
|
||||
first_event_id, next_batch_id, relay_user_id, expiration_time
|
||||
FROM portal
|
||||
`
|
||||
getPortalByJIDQuery = getAllPortalsQuery + " WHERE jid=$1 AND receiver=$2"
|
||||
getPortalByMXIDQuery = getAllPortalsQuery + " WHERE mxid=$1"
|
||||
getPrivateChatsWithQuery = getAllPortalsQuery + " WHERE jid=$1"
|
||||
getPrivateChatsOfQuery = getAllPortalsQuery + " WHERE receiver=$1"
|
||||
getAllPortalsByParentGroupQuery = getAllPortalsQuery + " WHERE parent_group=$1"
|
||||
findPrivateChatPortalsNotInSpaceQuery = `
|
||||
SELECT jid FROM portal
|
||||
LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
|
||||
WHERE mxid<>'' AND receiver=$1 AND (user_portal.in_space=false OR user_portal.in_space IS NULL)
|
||||
`, receiver)
|
||||
if err != nil {
|
||||
pq.log.Errorfln("Failed to find private chats not in space for %s: %v", receiver, err)
|
||||
return
|
||||
} else if rows == nil {
|
||||
return
|
||||
`
|
||||
|
||||
insertPortalQuery = `
|
||||
INSERT INTO portal (
|
||||
jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
|
||||
encrypted, last_sync, is_parent, parent_group, in_space,
|
||||
first_event_id, next_batch_id, relay_user_id, expiration_time
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
|
||||
`
|
||||
updatePortalQuery = `
|
||||
UPDATE portal
|
||||
SET mxid=$3, name=$4, name_set=$5, topic=$6, topic_set=$7, avatar=$8, avatar_url=$9, avatar_set=$10,
|
||||
encrypted=$11, last_sync=$12, is_parent=$13, parent_group=$14, in_space=$15,
|
||||
first_event_id=$16, next_batch_id=$17, relay_user_id=$18, expiration_time=$19
|
||||
WHERE jid=$1 AND receiver=$2
|
||||
`
|
||||
clearPortalInSpaceQuery = "UPDATE portal SET in_space=false WHERE parent_group=$1"
|
||||
deletePortalQuery = "DELETE FROM portal WHERE jid=$1 AND receiver=$2"
|
||||
)
|
||||
|
||||
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
|
||||
return pq.QueryMany(ctx, getAllPortalsQuery)
|
||||
}
|
||||
for rows.Next() {
|
||||
var key PortalKey
|
||||
|
||||
func (pq *PortalQuery) GetByJID(ctx context.Context, key PortalKey) (*Portal, error) {
|
||||
return pq.QueryOne(ctx, getPortalByJIDQuery, key.JID, key.Receiver)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) {
|
||||
return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAllByJID(ctx context.Context, jid types.JID) ([]*Portal, error) {
|
||||
return pq.QueryMany(ctx, getPrivateChatsWithQuery, jid.ToNonAD())
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) FindPrivateChats(ctx context.Context, receiver types.JID) ([]*Portal, error) {
|
||||
return pq.QueryMany(ctx, getPrivateChatsOfQuery, receiver.ToNonAD())
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAllByParentGroup(ctx context.Context, jid types.JID) ([]*Portal, error) {
|
||||
return pq.QueryMany(ctx, getAllPortalsByParentGroupQuery, jid)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) FindPrivateChatsNotInSpace(ctx context.Context, receiver types.JID) (keys []PortalKey, err error) {
|
||||
receiver = receiver.ToNonAD()
|
||||
scanFn := func(rows dbutil.Scannable) (key PortalKey, err error) {
|
||||
key.Receiver = receiver
|
||||
err = rows.Scan(&key.JID)
|
||||
if err == nil {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
|
||||
rows, err := pq.db.Query(query, args...)
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
portals = append(portals, pq.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
|
||||
row := pq.db.QueryRow(query, args...)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
return pq.New().Scan(row)
|
||||
return dbutil.ConvertRowFn[PortalKey](scanFn).
|
||||
NewRowIter(pq.GetDB().Query(ctx, findPrivateChatPortalsNotInSpaceQuery, receiver)).
|
||||
AsList()
|
||||
}
|
||||
|
||||
type Portal struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
qh *dbutil.QueryHelper[*Portal]
|
||||
|
||||
Key PortalKey
|
||||
MXID id.RoomID
|
||||
|
@ -161,15 +159,17 @@ type Portal struct {
|
|||
ExpirationTime uint32
|
||||
}
|
||||
|
||||
func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
|
||||
func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
|
||||
var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString
|
||||
var lastSyncTs int64
|
||||
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
|
||||
err := row.Scan(
|
||||
&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet,
|
||||
&portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted,
|
||||
&lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace,
|
||||
&firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime,
|
||||
)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
portal.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
if lastSyncTs > 0 {
|
||||
portal.LastSync = time.Unix(lastSyncTs, 0)
|
||||
|
@ -182,96 +182,36 @@ func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
|
|||
portal.FirstEventID = id.EventID(firstEventID.String)
|
||||
portal.NextBatchID = id.BatchID(nextBatchID.String)
|
||||
portal.RelayUserID = id.UserID(relayUserID.String)
|
||||
return portal
|
||||
return portal, nil
|
||||
}
|
||||
|
||||
func (portal *Portal) mxidPtr() *id.RoomID {
|
||||
if len(portal.MXID) > 0 {
|
||||
return &portal.MXID
|
||||
func (portal *Portal) sqlVariables() []any {
|
||||
var lastSyncTS int64
|
||||
if !portal.LastSync.IsZero() {
|
||||
lastSyncTS = portal.LastSync.Unix()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (portal *Portal) relayUserPtr() *id.UserID {
|
||||
if len(portal.RelayUserID) > 0 {
|
||||
return &portal.RelayUserID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (portal *Portal) parentGroupPtr() *string {
|
||||
if !portal.ParentGroup.IsEmpty() {
|
||||
val := portal.ParentGroup.String()
|
||||
return &val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (portal *Portal) lastSyncTs() int64 {
|
||||
if portal.LastSync.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return portal.LastSync.Unix()
|
||||
}
|
||||
|
||||
func (portal *Portal) Insert() {
|
||||
_, err := portal.db.Exec(`
|
||||
INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
|
||||
encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id,
|
||||
relay_user_id, expiration_time)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
|
||||
`,
|
||||
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet,
|
||||
portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(),
|
||||
portal.IsParent, portal.parentGroupPtr(), portal.InSpace, portal.FirstEventID.String(), portal.NextBatchID.String(),
|
||||
portal.relayUserPtr(), portal.ExpirationTime)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
|
||||
return []any{
|
||||
portal.Key.JID, portal.Key.Receiver, dbutil.StrPtr(portal.MXID), portal.Name, portal.NameSet,
|
||||
portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted,
|
||||
lastSyncTS, portal.IsParent, dbutil.StrPtr(portal.ParentGroup.String()), portal.InSpace,
|
||||
portal.FirstEventID.String(), portal.NextBatchID.String(), dbutil.StrPtr(portal.RelayUserID), portal.ExpirationTime,
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) Update(txn dbutil.Execable) {
|
||||
if txn == nil {
|
||||
txn = portal.db
|
||||
}
|
||||
_, err := txn.Exec(`
|
||||
UPDATE portal
|
||||
SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
|
||||
encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13,
|
||||
first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17
|
||||
WHERE jid=$18 AND receiver=$19
|
||||
`, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(),
|
||||
portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), portal.IsParent, portal.parentGroupPtr(), portal.InSpace,
|
||||
portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime,
|
||||
portal.Key.JID, portal.Key.Receiver)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
|
||||
}
|
||||
func (portal *Portal) Insert(ctx context.Context) error {
|
||||
return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (portal *Portal) Delete() {
|
||||
txn, err := portal.db.Begin()
|
||||
func (portal *Portal) Update(ctx context.Context) error {
|
||||
return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (portal *Portal) Delete(ctx context.Context) error {
|
||||
return portal.qh.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
err := portal.qh.Exec(ctx, clearPortalInSpaceQuery, portal.Key.JID)
|
||||
if err != nil {
|
||||
portal.log.Errorfln("Failed to begin transaction to delete portal %v: %v", portal.Key, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = txn.Rollback()
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to rollback failed portal delete transaction: %v", err)
|
||||
}
|
||||
} else if err = txn.Commit(); err != nil {
|
||||
portal.log.Warnfln("Failed to commit portal delete transaction: %v", err)
|
||||
}
|
||||
}()
|
||||
_, err = txn.Exec("UPDATE portal SET in_space=false WHERE parent_group=$1", portal.Key.JID)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to mark child groups of %v as not in space: %v", portal.Key.JID, err)
|
||||
return
|
||||
}
|
||||
_, err = txn.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to delete %v: %v", portal.Key, err)
|
||||
return err
|
||||
}
|
||||
return portal.qh.Exec(ctx, deletePortalQuery, portal.Key.JID, portal.Key.Receiver)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2021 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,74 +17,70 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
type PuppetQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*Puppet]
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) New() *Puppet {
|
||||
func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet {
|
||||
return &Puppet{
|
||||
db: pq.db,
|
||||
log: pq.log,
|
||||
qh: qh,
|
||||
|
||||
EnablePresence: true,
|
||||
EnableReceipts: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
|
||||
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet")
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
puppets = append(puppets, pq.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
const (
|
||||
getAllPuppetsQuery = `
|
||||
SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set,
|
||||
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts
|
||||
FROM puppet
|
||||
`
|
||||
getPuppetByJIDQuery = getAllPuppetsQuery + " WHERE username=$1"
|
||||
getPuppetByCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid=$1"
|
||||
getAllPuppetsWithCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid<>''"
|
||||
insertPuppetQuery = `
|
||||
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set,
|
||||
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
`
|
||||
updatePuppetQuery = `
|
||||
UPDATE puppet
|
||||
SET avatar=$2, avatar_url=$3, avatar_set=$4, displayname=$5, name_quality=$6, name_set=$7, contact_info_set=$8,
|
||||
last_sync=$9, custom_mxid=$10, access_token=$11, next_batch=$12, enable_presence=$13, enable_receipts=$14
|
||||
WHERE username=$1
|
||||
`
|
||||
)
|
||||
|
||||
func (pq *PuppetQuery) GetAll(ctx context.Context) ([]*Puppet, error) {
|
||||
return pq.QueryMany(ctx, getAllPuppetsQuery)
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) Get(jid types.JID) *Puppet {
|
||||
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE username=$1", jid.User)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
return pq.New().Scan(row)
|
||||
func (pq *PuppetQuery) Get(ctx context.Context, jid types.JID) (*Puppet, error) {
|
||||
return pq.QueryOne(ctx, getPuppetByJIDQuery, jid.User)
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
|
||||
row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid=$1", mxid)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
return pq.New().Scan(row)
|
||||
func (pq *PuppetQuery) GetByCustomMXID(ctx context.Context, mxid id.UserID) (*Puppet, error) {
|
||||
return pq.QueryOne(ctx, getPuppetByCustomMXIDQuery, mxid)
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) {
|
||||
rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid<>''")
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
puppets = append(puppets, pq.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, error) {
|
||||
return pq.QueryMany(ctx, getAllPuppetsWithCustomMXIDQuery)
|
||||
}
|
||||
|
||||
type Puppet struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
qh *dbutil.QueryHelper[*Puppet]
|
||||
|
||||
JID types.JID
|
||||
Avatar string
|
||||
|
@ -103,17 +99,14 @@ type Puppet struct {
|
|||
EnableReceipts bool
|
||||
}
|
||||
|
||||
func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
|
||||
func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) {
|
||||
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
|
||||
var quality, lastSync sql.NullInt64
|
||||
var enablePresence, enableReceipts, nameSet, avatarSet, contactInfoSet sql.NullBool
|
||||
var username string
|
||||
err := row.Scan(&username, &avatar, &avatarURL, &displayname, &quality, &nameSet, &avatarSet, &contactInfoSet, &lastSync, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
puppet.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
puppet.JID = types.NewJID(username, types.DefaultUserServer)
|
||||
puppet.Displayname = displayname.String
|
||||
|
@ -131,45 +124,30 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
|
|||
puppet.NextBatch = nextBatch.String
|
||||
puppet.EnablePresence = enablePresence.Bool
|
||||
puppet.EnableReceipts = enableReceipts.Bool
|
||||
return puppet
|
||||
return puppet, nil
|
||||
}
|
||||
|
||||
func (puppet *Puppet) Insert() {
|
||||
if puppet.JID.Server != types.DefaultUserServer {
|
||||
puppet.log.Warnfln("Not inserting %s: not a user", puppet.JID)
|
||||
return
|
||||
}
|
||||
var lastSyncTs int64
|
||||
func (puppet *Puppet) sqlVariables() []any {
|
||||
var lastSyncTS int64
|
||||
if !puppet.LastSync.IsZero() {
|
||||
lastSyncTs = puppet.LastSync.Unix()
|
||||
lastSyncTS = puppet.LastSync.Unix()
|
||||
}
|
||||
_, err := puppet.db.Exec(`
|
||||
INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set,
|
||||
last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
`, puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
|
||||
puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
|
||||
return []any{
|
||||
puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname,
|
||||
puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTS,
|
||||
puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch,
|
||||
puppet.EnablePresence, puppet.EnableReceipts,
|
||||
)
|
||||
if err != nil {
|
||||
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (puppet *Puppet) Update() {
|
||||
var lastSyncTs int64
|
||||
if !puppet.LastSync.IsZero() {
|
||||
lastSyncTs = puppet.LastSync.Unix()
|
||||
func (puppet *Puppet) Insert(ctx context.Context) error {
|
||||
if puppet.JID.Server != types.DefaultUserServer {
|
||||
zerolog.Ctx(ctx).Warn().Stringer("jid", puppet.JID).Msg("Not inserting puppet: not a user")
|
||||
return nil
|
||||
}
|
||||
_, err := puppet.db.Exec(`
|
||||
UPDATE puppet
|
||||
SET displayname=$1, name_quality=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, contact_info_set=$7, last_sync=$8,
|
||||
custom_mxid=$9, access_token=$10, next_batch=$11, enable_presence=$12, enable_receipts=$13
|
||||
WHERE username=$14
|
||||
`, puppet.Displayname, puppet.NameQuality, puppet.NameSet, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.ContactInfoSet,
|
||||
lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts,
|
||||
puppet.JID.User)
|
||||
if err != nil {
|
||||
puppet.log.Warnfln("Failed to update %s: %v", puppet.JID, err)
|
||||
return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (puppet *Puppet) Update(ctx context.Context) error {
|
||||
return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,26 +17,20 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"context"
|
||||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type ReactionQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*Reaction]
|
||||
}
|
||||
|
||||
func (rq *ReactionQuery) New() *Reaction {
|
||||
return &Reaction{
|
||||
db: rq.db,
|
||||
log: rq.log,
|
||||
}
|
||||
func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction {
|
||||
return &Reaction{qh: qh}
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -55,28 +49,20 @@ const (
|
|||
DO UPDATE SET mxid=excluded.mxid, jid=excluded.jid
|
||||
`
|
||||
deleteReactionQuery = `
|
||||
DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4 AND mxid=$5
|
||||
DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4
|
||||
`
|
||||
)
|
||||
|
||||
func (rq *ReactionQuery) GetByTargetJID(chat PortalKey, jid types.MessageID, sender types.JID) *Reaction {
|
||||
return rq.maybeScan(rq.db.QueryRow(getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD()))
|
||||
func (rq *ReactionQuery) GetByTargetJID(ctx context.Context, chat PortalKey, jid types.MessageID, sender types.JID) (*Reaction, error) {
|
||||
return rq.QueryOne(ctx, getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD())
|
||||
}
|
||||
|
||||
func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction {
|
||||
return rq.maybeScan(rq.db.QueryRow(getReactionByMXIDQuery, mxid))
|
||||
}
|
||||
|
||||
func (rq *ReactionQuery) maybeScan(row *sql.Row) *Reaction {
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
return rq.New().Scan(row)
|
||||
func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) {
|
||||
return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid)
|
||||
}
|
||||
|
||||
type Reaction struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
qh *dbutil.QueryHelper[*Reaction]
|
||||
|
||||
Chat PortalKey
|
||||
TargetJID types.MessageID
|
||||
|
@ -85,35 +71,19 @@ type Reaction struct {
|
|||
JID types.MessageID
|
||||
}
|
||||
|
||||
func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction {
|
||||
err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
reaction.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return reaction
|
||||
func (reaction *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
|
||||
return dbutil.ValueOrErr(reaction, row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID))
|
||||
}
|
||||
|
||||
func (reaction *Reaction) Upsert(txn dbutil.Execable) {
|
||||
func (reaction *Reaction) sqlVariables() []any {
|
||||
reaction.Sender = reaction.Sender.ToNonAD()
|
||||
if txn == nil {
|
||||
txn = reaction.db
|
||||
}
|
||||
_, err := txn.Exec(upsertReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID)
|
||||
if err != nil {
|
||||
reaction.log.Warnfln("Failed to upsert reaction to %s@%s by %s: %v", reaction.Chat, reaction.TargetJID, reaction.Sender, err)
|
||||
}
|
||||
return []any{reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID}
|
||||
}
|
||||
|
||||
func (reaction *Reaction) GetTarget() *Message {
|
||||
return reaction.db.Message.GetByJID(reaction.Chat, reaction.TargetJID)
|
||||
func (reaction *Reaction) Upsert(ctx context.Context) error {
|
||||
return reaction.qh.Exec(ctx, upsertReactionQuery, reaction.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (reaction *Reaction) Delete() {
|
||||
_, err := reaction.db.Exec(deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID)
|
||||
if err != nil {
|
||||
reaction.log.Warnfln("Failed to delete reaction %s: %v", reaction.MXID, err)
|
||||
}
|
||||
func (reaction *Reaction) Delete(ctx context.Context) error {
|
||||
return reaction.qh.Exec(ctx, deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender)
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"errors"
|
||||
|
||||
|
@ -29,7 +30,7 @@ var Table dbutil.UpgradeTable
|
|||
var rawUpgrades embed.FS
|
||||
|
||||
func init() {
|
||||
Table.Register(-1, 35, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
|
||||
Table.Register(-1, 35, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
|
||||
return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
|
||||
})
|
||||
Table.RegisterFS(rawUpgrades)
|
||||
|
|
157
database/user.go
157
database/user.go
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,63 +17,65 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
type UserQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
*dbutil.QueryHelper[*User]
|
||||
}
|
||||
|
||||
func (uq *UserQuery) New() *User {
|
||||
func newUser(qh *dbutil.QueryHelper[*User]) *User {
|
||||
return &User{
|
||||
db: uq.db,
|
||||
log: uq.log,
|
||||
qh: qh,
|
||||
|
||||
lastReadCache: make(map[PortalKey]time.Time),
|
||||
inSpaceCache: make(map[PortalKey]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetAll() (users []*User) {
|
||||
rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`)
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
users = append(users, uq.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
const (
|
||||
getAllUsersQuery = `SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`
|
||||
getUserByMXIDQuery = getAllUsersQuery + ` WHERE mxid=$1`
|
||||
getUserByUsernameQuery = getAllUsersQuery + ` WHERE username=$1`
|
||||
insertUserQuery = `
|
||||
INSERT INTO "user" (
|
||||
mxid, username, agent, device,
|
||||
management_room, space_room,
|
||||
phone_last_seen, phone_last_pinged, timezone
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
updateUserQuery = `
|
||||
UPDATE "user"
|
||||
SET username=$1, agent=$2, device=$3,
|
||||
management_room=$4, space_room=$5,
|
||||
phone_last_seen=$6, phone_last_pinged=$7, timezone=$8
|
||||
WHERE mxid=$9
|
||||
`
|
||||
getUserLastAppStateKeyIDQuery = "SELECT key_id FROM whatsmeow_app_state_sync_keys WHERE jid=$1 ORDER BY timestamp DESC LIMIT 1"
|
||||
)
|
||||
|
||||
func (uq *UserQuery) GetAll(ctx context.Context) ([]*User, error) {
|
||||
return uq.QueryMany(ctx, getAllUsersQuery)
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
|
||||
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE mxid=$1`, userID)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
return uq.New().Scan(row)
|
||||
func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) {
|
||||
return uq.QueryOne(ctx, getUserByMXIDQuery, userID)
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetByUsername(username string) *User {
|
||||
row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE username=$1`, username)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
return uq.New().Scan(row)
|
||||
func (uq *UserQuery) GetByUsername(ctx context.Context, username string) (*User, error) {
|
||||
return uq.QueryOne(ctx, getUserByUsernameQuery, username)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
qh *dbutil.QueryHelper[*User]
|
||||
|
||||
MXID id.UserID
|
||||
JID types.JID
|
||||
|
@ -89,20 +91,21 @@ type User struct {
|
|||
inSpaceCacheLock sync.Mutex
|
||||
}
|
||||
|
||||
func (user *User) Scan(row dbutil.Scannable) *User {
|
||||
func (user *User) Scan(row dbutil.Scannable) (*User, error) {
|
||||
var username, timezone sql.NullString
|
||||
var device, agent sql.NullByte
|
||||
var device, agent sql.NullInt16
|
||||
var phoneLastSeen, phoneLastPinged sql.NullInt64
|
||||
err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged, &timezone)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
user.log.Errorln("Database scan failed:", err)
|
||||
}
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
user.Timezone = timezone.String
|
||||
if len(username.String) > 0 {
|
||||
user.JID = types.NewADJID(username.String, agent.Byte, device.Byte)
|
||||
user.JID = types.JID{
|
||||
User: username.String,
|
||||
Device: uint16(device.Int16),
|
||||
Server: types.DefaultUserServer,
|
||||
}
|
||||
}
|
||||
if phoneLastSeen.Valid {
|
||||
user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0)
|
||||
|
@ -110,66 +113,34 @@ func (user *User) Scan(row dbutil.Scannable) *User {
|
|||
if phoneLastPinged.Valid {
|
||||
user.PhoneLastPinged = time.Unix(phoneLastPinged.Int64, 0)
|
||||
}
|
||||
return user
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (user *User) usernamePtr() *string {
|
||||
func (user *User) sqlVariables() []any {
|
||||
var username *string
|
||||
var agent, device *uint16
|
||||
if !user.JID.IsEmpty() {
|
||||
return &user.JID.User
|
||||
username = dbutil.StrPtr(user.JID.User)
|
||||
var zero uint16
|
||||
agent = &zero
|
||||
device = dbutil.NumPtr(user.JID.Device)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) agentPtr() *uint8 {
|
||||
if !user.JID.IsEmpty() {
|
||||
zero := uint8(0)
|
||||
return &zero
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) devicePtr() *uint8 {
|
||||
if !user.JID.IsEmpty() {
|
||||
device := uint8(user.JID.Device)
|
||||
return &device
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) phoneLastSeenPtr() *int64 {
|
||||
if user.PhoneLastSeen.IsZero() {
|
||||
return nil
|
||||
}
|
||||
ts := user.PhoneLastSeen.Unix()
|
||||
return &ts
|
||||
}
|
||||
|
||||
func (user *User) phoneLastPingedPtr() *int64 {
|
||||
if user.PhoneLastPinged.IsZero() {
|
||||
return nil
|
||||
}
|
||||
ts := user.PhoneLastPinged.Unix()
|
||||
return &ts
|
||||
}
|
||||
|
||||
func (user *User) Insert() {
|
||||
_, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone)
|
||||
if err != nil {
|
||||
user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
|
||||
return []any{
|
||||
username, agent, device, user.ManagementRoom, user.SpaceRoom,
|
||||
dbutil.UnixPtr(user.PhoneLastSeen), dbutil.UnixPtr(user.PhoneLastPinged),
|
||||
user.Timezone, user.MXID,
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) Update() {
|
||||
_, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7, timezone=$8 WHERE mxid=$9`,
|
||||
user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone, user.MXID)
|
||||
if err != nil {
|
||||
user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
|
||||
}
|
||||
func (user *User) Insert(ctx context.Context) error {
|
||||
return user.qh.Exec(ctx, insertUserQuery, user.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (user *User) GetLastAppStateKeyID() ([]byte, error) {
|
||||
var keyID []byte
|
||||
err := user.db.QueryRow("SELECT key_id FROM whatsmeow_app_state_sync_keys ORDER BY timestamp DESC LIMIT 1").Scan(&keyID)
|
||||
return keyID, err
|
||||
func (user *User) Update(ctx context.Context) error {
|
||||
return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (user *User) GetLastAppStateKeyID(ctx context.Context) (keyID []byte, err error) {
|
||||
err = user.qh.GetDB().QueryRow(ctx, getUserLastAppStateKeyIDQuery, user.JID).Scan(&keyID)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2021 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,69 +17,97 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func (user *User) GetLastReadTS(portal PortalKey) time.Time {
|
||||
const (
|
||||
getLastReadTSQuery = "SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
|
||||
setLastReadTSQuery = `
|
||||
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
|
||||
`
|
||||
getIsInSpaceQuery = "SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3"
|
||||
setIsInSpaceQuery = `
|
||||
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
|
||||
`
|
||||
)
|
||||
|
||||
func (user *User) GetLastReadTS(ctx context.Context, portal PortalKey) time.Time {
|
||||
user.lastReadCacheLock.Lock()
|
||||
defer user.lastReadCacheLock.Unlock()
|
||||
if cached, ok := user.lastReadCache[portal]; ok {
|
||||
return cached
|
||||
}
|
||||
var ts int64
|
||||
err := user.db.QueryRow("SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&ts)
|
||||
var parsedTS time.Time
|
||||
err := user.qh.GetDB().QueryRow(ctx, getLastReadTSQuery, user.MXID, portal.JID, portal.Receiver).Scan(&ts)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
user.log.Warnfln("Failed to scan last read timestamp from user portal table: %v", err)
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Str("user_id", user.MXID.String()).
|
||||
Any("portal_key", portal).
|
||||
Msg("Failed to query last read timestamp")
|
||||
return parsedTS
|
||||
}
|
||||
if ts == 0 {
|
||||
user.lastReadCache[portal] = time.Time{}
|
||||
} else {
|
||||
user.lastReadCache[portal] = time.Unix(ts, 0)
|
||||
if ts != 0 {
|
||||
parsedTS = time.Unix(ts, 0)
|
||||
}
|
||||
user.lastReadCache[portal] = parsedTS
|
||||
return user.lastReadCache[portal]
|
||||
}
|
||||
|
||||
func (user *User) SetLastReadTS(portal PortalKey, ts time.Time) {
|
||||
func (user *User) SetLastReadTS(ctx context.Context, portal PortalKey, ts time.Time) {
|
||||
user.lastReadCacheLock.Lock()
|
||||
defer user.lastReadCacheLock.Unlock()
|
||||
_, err := user.db.Exec(`
|
||||
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts<excluded.last_read_ts
|
||||
`, user.MXID, portal.JID, portal.Receiver, ts.Unix())
|
||||
_, err := user.qh.GetDB().Exec(ctx, setLastReadTSQuery, user.MXID, portal.JID, portal.Receiver, ts.Unix())
|
||||
if err != nil {
|
||||
user.log.Warnfln("Failed to update last read timestamp: %v", err)
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Str("user_id", user.MXID.String()).
|
||||
Any("portal_key", portal).
|
||||
Msg("Failed to update last read timestamp")
|
||||
} else {
|
||||
user.log.Debugfln("Set last read timestamp of %s in %s to %d", user.MXID, portal.String(), ts.Unix())
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Str("user_id", user.MXID.String()).
|
||||
Any("portal_key", portal).
|
||||
Time("last_read_ts", ts).
|
||||
Msg("Updated last read timestamp of portal")
|
||||
user.lastReadCache[portal] = ts
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) IsInSpace(portal PortalKey) bool {
|
||||
func (user *User) IsInSpace(ctx context.Context, portal PortalKey) bool {
|
||||
user.inSpaceCacheLock.Lock()
|
||||
defer user.inSpaceCacheLock.Unlock()
|
||||
if cached, ok := user.inSpaceCache[portal]; ok {
|
||||
return cached
|
||||
}
|
||||
var inSpace bool
|
||||
err := user.db.QueryRow("SELECT in_space FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3", user.MXID, portal.JID, portal.Receiver).Scan(&inSpace)
|
||||
err := user.qh.GetDB().QueryRow(ctx, getIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver).Scan(&inSpace)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
user.log.Warnfln("Failed to scan in space status from user portal table: %v", err)
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Str("user_id", user.MXID.String()).
|
||||
Any("portal_key", portal).
|
||||
Msg("Failed to query in space status")
|
||||
return false
|
||||
}
|
||||
user.inSpaceCache[portal] = inSpace
|
||||
return inSpace
|
||||
}
|
||||
|
||||
func (user *User) MarkInSpace(portal PortalKey) {
|
||||
func (user *User) MarkInSpace(ctx context.Context, portal PortalKey) {
|
||||
user.inSpaceCacheLock.Lock()
|
||||
defer user.inSpaceCacheLock.Unlock()
|
||||
_, err := user.db.Exec(`
|
||||
INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, in_space) VALUES ($1, $2, $3, true)
|
||||
ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET in_space=true
|
||||
`, user.MXID, portal.JID, portal.Receiver)
|
||||
_, err := user.qh.GetDB().Exec(ctx, setIsInSpaceQuery, user.MXID, portal.JID, portal.Receiver)
|
||||
if err != nil {
|
||||
user.log.Warnfln("Failed to update in space status: %v", err)
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Str("user_id", user.MXID.String()).
|
||||
Any("portal_key", portal).
|
||||
Msg("Failed to update in space status")
|
||||
} else {
|
||||
user.inSpaceCache[portal] = true
|
||||
}
|
||||
|
|
58
disappear.go
58
disappear.go
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,10 +17,11 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
@ -28,47 +29,74 @@ import (
|
|||
"maunium.net/go/mautrix-whatsapp/database"
|
||||
)
|
||||
|
||||
func (portal *Portal) MarkDisappearing(txn dbutil.Execable, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) {
|
||||
func (portal *Portal) MarkDisappearing(ctx context.Context, eventID id.EventID, expiresIn time.Duration, startsAt time.Time) {
|
||||
if expiresIn == 0 {
|
||||
return
|
||||
}
|
||||
expiresAt := startsAt.Add(expiresIn)
|
||||
|
||||
msg := portal.bridge.DB.DisappearingMessage.NewWithValues(portal.MXID, eventID, expiresIn, expiresAt)
|
||||
msg.Insert(txn)
|
||||
err := msg.Insert(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to insert disappearing message")
|
||||
}
|
||||
if expiresAt.Before(time.Now().Add(1 * time.Hour)) {
|
||||
go portal.sleepAndDelete(msg)
|
||||
go portal.sleepAndDelete(context.WithoutCancel(ctx), msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (br *WABridge) SleepAndDeleteUpcoming() {
|
||||
for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
|
||||
func (br *WABridge) SleepAndDeleteUpcoming(ctx context.Context) {
|
||||
msgs, err := br.DB.DisappearingMessage.GetUpcomingScheduled(ctx, 1*time.Hour)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get upcoming disappearing messages")
|
||||
return
|
||||
}
|
||||
for _, msg := range msgs {
|
||||
portal := br.GetPortalByMXID(msg.RoomID)
|
||||
if portal == nil {
|
||||
msg.Delete()
|
||||
err = msg.Delete(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("event_id", msg.EventID).
|
||||
Msg("Failed to delete disappearing message row with no portal")
|
||||
}
|
||||
} else {
|
||||
go portal.sleepAndDelete(msg)
|
||||
go portal.sleepAndDelete(ctx, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) sleepAndDelete(msg *database.DisappearingMessage) {
|
||||
func (portal *Portal) sleepAndDelete(ctx context.Context, msg *database.DisappearingMessage) {
|
||||
if _, alreadySleeping := portal.currentlySleepingToDelete.LoadOrStore(msg.EventID, true); alreadySleeping {
|
||||
return
|
||||
}
|
||||
defer portal.currentlySleepingToDelete.Delete(msg.EventID)
|
||||
log := zerolog.Ctx(ctx)
|
||||
|
||||
sleepTime := msg.ExpireAt.Sub(time.Now())
|
||||
portal.log.Debugfln("Sleeping for %s to make %s disappear", sleepTime, msg.EventID)
|
||||
log.Debug().
|
||||
Stringer("room_id", portal.MXID).
|
||||
Stringer("event_id", msg.EventID).
|
||||
Dur("sleep_time", sleepTime).
|
||||
Msg("Sleeping before making message disappear")
|
||||
time.Sleep(sleepTime)
|
||||
_, err := portal.MainIntent().RedactEvent(msg.RoomID, msg.EventID, mautrix.ReqRedact{
|
||||
_, err := portal.MainIntent().RedactEvent(ctx, msg.RoomID, msg.EventID, mautrix.ReqRedact{
|
||||
Reason: "Message expired",
|
||||
TxnID: fmt.Sprintf("mxwa_disappear_%s", msg.EventID),
|
||||
})
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to make %s disappear: %v", msg.EventID, err)
|
||||
log.Err(err).
|
||||
Stringer("room_id", portal.MXID).
|
||||
Stringer("event_id", msg.EventID).
|
||||
Msg("Failed to make event disappear")
|
||||
} else {
|
||||
portal.log.Debugfln("Disappeared %s", msg.EventID)
|
||||
log.Debug().
|
||||
Stringer("room_id", portal.MXID).
|
||||
Stringer("event_id", msg.EventID).
|
||||
Msg("Disappeared event")
|
||||
}
|
||||
err = msg.Delete(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete disapperaing message row in database after redacting event")
|
||||
}
|
||||
msg.Delete()
|
||||
}
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"golang.org/x/exp/slices"
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
@ -104,22 +106,27 @@ func NewFormatter(bridge *WABridge) *Formatter {
|
|||
return formatter
|
||||
}
|
||||
|
||||
func (formatter *Formatter) getMatrixInfoByJID(roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) {
|
||||
func (formatter *Formatter) getMatrixInfoByJID(ctx context.Context, roomID id.RoomID, jid types.JID) (mxid id.UserID, displayname string) {
|
||||
if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
|
||||
mxid = puppet.MXID
|
||||
displayname = puppet.Displayname
|
||||
}
|
||||
if user := formatter.bridge.GetUserByJID(jid); user != nil {
|
||||
mxid = user.MXID
|
||||
member := formatter.bridge.StateStore.GetMember(roomID, user.MXID)
|
||||
if len(member.Displayname) > 0 {
|
||||
member, err := formatter.bridge.StateStore.GetMember(ctx, roomID, user.MXID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("room_id", roomID).
|
||||
Stringer("user_id", user.MXID).
|
||||
Msg("Failed to get member profile from state store")
|
||||
} else if len(member.Displayname) > 0 {
|
||||
displayname = member.Displayname
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) {
|
||||
func (formatter *Formatter) ParseWhatsApp(ctx context.Context, roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) {
|
||||
output := html.EscapeString(content.Body)
|
||||
for regex, replacement := range formatter.waReplString {
|
||||
output = regex.ReplaceAllString(output, replacement)
|
||||
|
@ -145,7 +152,7 @@ func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.Messa
|
|||
// TODO lid support?
|
||||
continue
|
||||
}
|
||||
mxid, displayname := formatter.getMatrixInfoByJID(roomID, jid)
|
||||
mxid, displayname := formatter.getMatrixInfoByJID(ctx, roomID, jid)
|
||||
number := "@" + jid.User
|
||||
output = strings.ReplaceAll(output, number, fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname))
|
||||
content.Body = strings.ReplaceAll(content.Body, number, displayname)
|
||||
|
|
45
go.mod
45
go.mod
|
@ -1,25 +1,25 @@
|
|||
module maunium.net/go/mautrix-whatsapp
|
||||
|
||||
go 1.20
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-sqlite3 v1.14.19
|
||||
github.com/prometheus/client_golang v1.17.0
|
||||
github.com/rs/zerolog v1.31.0
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/prometheus/client_golang v1.19.0
|
||||
github.com/rs/zerolog v1.32.0
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/tidwall/gjson v1.17.0
|
||||
go.mau.fi/util v0.2.1
|
||||
github.com/tidwall/gjson v1.17.1
|
||||
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e
|
||||
go.mau.fi/webp v0.1.0
|
||||
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735
|
||||
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611
|
||||
golang.org/x/image v0.14.0
|
||||
golang.org/x/net v0.19.0
|
||||
google.golang.org/protobuf v1.31.0
|
||||
maunium.net/go/maulogger/v2 v2.4.1
|
||||
maunium.net/go/mautrix v0.16.2
|
||||
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462
|
||||
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
|
||||
golang.org/x/image v0.15.0
|
||||
golang.org/x/net v0.22.0
|
||||
google.golang.org/protobuf v1.33.0
|
||||
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -27,24 +27,27 @@ require (
|
|||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
|
||||
github.com/prometheus/common v0.44.0 // indirect
|
||||
github.com/prometheus/procfs v0.11.1 // indirect
|
||||
github.com/prometheus/client_model v0.5.0 // indirect
|
||||
github.com/prometheus/common v0.48.0 // indirect
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/yuin/goldmark v1.6.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.0 // indirect
|
||||
go.mau.fi/libsignal v0.1.0 // indirect
|
||||
go.mau.fi/zeroconfig v0.1.2 // indirect
|
||||
golang.org/x/crypto v0.16.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.21.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
maunium.net/go/mauflag v1.0.0 // indirect
|
||||
)
|
||||
|
||||
//replace go.mau.fi/util => ../../Go/go-util
|
||||
//replace maunium.net/go/mautrix => ../mautrix-go
|
||||
|
|
100
go.sum
100
go.sum
|
@ -1,6 +1,9 @@
|
|||
filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek=
|
||||
filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c h1:WqjRVgUO039eiISCjsZC4F9onOEV93DJAk6v33rsZzY=
|
||||
github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c/go.mod h1:b9FFm9y4mEm36G8ytVmS1vkNzJa0KepmcdVY+qf7qRU=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
|
@ -9,18 +12,18 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
|
|||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
|
@ -30,78 +33,75 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
|
||||
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
|
||||
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
|
||||
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
|
||||
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
|
||||
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
|
||||
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
|
||||
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
|
||||
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU=
|
||||
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
|
||||
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
|
||||
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
|
||||
github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
|
||||
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
|
||||
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
|
||||
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
|
||||
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
|
||||
github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
|
||||
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
|
||||
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
|
||||
github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
go.mau.fi/libsignal v0.1.0 h1:vAKI/nJ5tMhdzke4cTK1fb0idJzz1JuEIpmjprueC+c=
|
||||
go.mau.fi/libsignal v0.1.0/go.mod h1:R8ovrTezxtUNzCQE5PH30StOQWWeBskBsWE55vMfY9I=
|
||||
go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw=
|
||||
go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c=
|
||||
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e h1:e1jDj/MjleSS5r9DMRbuCZYKy5Rr+sbsu8eWjtLqrGk=
|
||||
go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e/go.mod h1:jOAREC/go8T6rGic01cu6WRa90xi9U4z3QmDjRf8xpo=
|
||||
go.mau.fi/webp v0.1.0 h1:BHObH/DcFntT9KYun5pDr0Ot4eUZO8k2C7eP7vF4ueA=
|
||||
go.mau.fi/webp v0.1.0/go.mod h1:e42Z+VMFrUMS9cpEwGRIor+lQWO8oUAyPyMtcL+NMt8=
|
||||
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735 h1:+teJYCOK6M4Kn2TYCj29levhHVwnJTmgCtEXLtgwQtM=
|
||||
go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735/go.mod h1:5xTtHNaZpGni6z6aE1iEopjW7wNgsKcolZxZrOujK9M=
|
||||
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462 h1:QOGjCIh2WEfkgX/38KLjnNof79GWx0T+KLrhTHiws3s=
|
||||
go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462/go.mod h1:lQHbhaG/fI+6hfGqz5Vzn2OBJBEZ05H0kCP6iJXriN4=
|
||||
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
|
||||
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
|
||||
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
|
||||
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 h1:qCEDpW1G+vcj3Y7Fy52pEM1AWm3abj8WimGYejI3SC4=
|
||||
golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI=
|
||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
|
||||
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
|
||||
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
|
||||
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
|
||||
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
|
||||
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
|
||||
maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg=
|
||||
maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4=
|
||||
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa h1:TLSWIAWKIWxLghgzWfp7o92pVCcFR3yLsArc0s/tsMs=
|
||||
maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa/go.mod h1:0sfLB2ejW+lhgio4UlZMmn5i9SuZ8mxFkonFSamrfTE=
|
||||
|
|
415
historysync.go
415
historysync.go
|
@ -17,6 +17,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
@ -24,11 +25,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/variationselector"
|
||||
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
"go.mau.fi/util/variationselector"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
@ -64,9 +65,8 @@ func (user *User) handleHistorySyncsLoop() {
|
|||
if batchSend {
|
||||
// Start the backfill queue.
|
||||
user.BackfillQueue = &BackfillQueue{
|
||||
BackfillQuery: user.bridge.DB.Backfill,
|
||||
BackfillQuery: user.bridge.DB.BackfillQueue,
|
||||
reCheckChannels: []chan bool{},
|
||||
log: user.log.Sub("BackfillQueue"),
|
||||
}
|
||||
|
||||
forwardAndImmediate := []database.BackfillType{database.BackfillImmediate, database.BackfillForward}
|
||||
|
@ -109,33 +109,52 @@ func (user *User) handleHistorySyncsLoop() {
|
|||
const EnqueueBackfillsDelay = 30 * time.Second
|
||||
|
||||
func (user *User) enqueueAllBackfills() {
|
||||
nMostRecent := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations)
|
||||
if len(nMostRecent) > 0 {
|
||||
user.log.Infofln("%v has passed since the last history sync blob, enqueueing backfills for %d chats", EnqueueBackfillsDelay, len(nMostRecent))
|
||||
log := user.zlog.With().
|
||||
Str("method", "User.enqueueAllBackfills").
|
||||
Logger()
|
||||
ctx := log.WithContext(context.TODO())
|
||||
nMostRecent, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get recent history sync conversations from database")
|
||||
return
|
||||
} else if len(nMostRecent) == 0 {
|
||||
return
|
||||
}
|
||||
log.Info().
|
||||
Int("chat_count", len(nMostRecent)).
|
||||
Msg("Enqueueing backfills for recent chats in history sync")
|
||||
// Find the portals for all the conversations.
|
||||
portals := []*Portal{}
|
||||
portals := make([]*Portal, 0, len(nMostRecent))
|
||||
for _, conv := range nMostRecent {
|
||||
jid, err := types.ParseJID(conv.ConversationID)
|
||||
if err != nil {
|
||||
user.log.Warnfln("Failed to parse chat JID '%s' in history sync: %v", conv.ConversationID, err)
|
||||
log.Err(err).Str("conversation_id", conv.ConversationID).Msg("Failed to parse chat JID in history sync")
|
||||
continue
|
||||
}
|
||||
portals = append(portals, user.GetPortalByJID(jid))
|
||||
}
|
||||
|
||||
user.EnqueueImmediateBackfills(portals)
|
||||
user.EnqueueForwardBackfills(portals)
|
||||
user.EnqueueDeferredBackfills(portals)
|
||||
user.EnqueueImmediateBackfills(ctx, portals)
|
||||
user.EnqueueForwardBackfills(ctx, portals)
|
||||
user.EnqueueDeferredBackfills(ctx, portals)
|
||||
|
||||
// Tell the queue to check for new backfill requests.
|
||||
user.BackfillQueue.ReCheck()
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) backfillAll() {
|
||||
conversations := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, -1)
|
||||
if len(conversations) > 0 {
|
||||
user.zlog.Info().
|
||||
log := user.zlog.With().
|
||||
Str("method", "User.backfillAll").
|
||||
Logger()
|
||||
ctx := log.WithContext(context.TODO())
|
||||
conversations, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, -1)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get history sync conversations from database")
|
||||
return
|
||||
} else if len(conversations) == 0 {
|
||||
return
|
||||
}
|
||||
log.Info().
|
||||
Int("conversation_count", len(conversations)).
|
||||
Msg("Probably received all history sync blobs, now backfilling conversations")
|
||||
limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations
|
||||
|
@ -144,45 +163,59 @@ func (user *User) backfillAll() {
|
|||
for _, conv := range conversations {
|
||||
jid, err := types.ParseJID(conv.ConversationID)
|
||||
if err != nil {
|
||||
user.zlog.Warn().Err(err).
|
||||
log.Err(err).
|
||||
Str("conversation_id", conv.ConversationID).
|
||||
Msg("Failed to parse chat JID in history sync")
|
||||
continue
|
||||
}
|
||||
portal := user.GetPortalByJID(jid)
|
||||
if portal.MXID != "" {
|
||||
user.zlog.Debug().
|
||||
log.Debug().
|
||||
Str("portal_jid", portal.Key.JID.String()).
|
||||
Msg("Chat already has a room, deleting messages from database")
|
||||
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
|
||||
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
|
||||
if err != nil {
|
||||
log.Err(err).Str("portal_jid", portal.Key.JID.String()).
|
||||
Msg("Failed to delete history sync conversation with existing portal from database")
|
||||
}
|
||||
bridgedCount++
|
||||
} else if !user.bridge.DB.HistorySync.ConversationHasMessages(user.MXID, portal.Key) {
|
||||
user.zlog.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync")
|
||||
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
|
||||
} else if hasMessages, err := user.bridge.DB.HistorySync.ConversationHasMessages(ctx, user.MXID, portal.Key); err != nil {
|
||||
log.Err(err).Str("portal_jid", portal.Key.JID.String()).Msg("Failed to check if chat has messages in history sync")
|
||||
} else if !hasMessages {
|
||||
log.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync")
|
||||
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
|
||||
if err != nil {
|
||||
log.Err(err).Str("portal_jid", portal.Key.JID.String()).
|
||||
Msg("Failed to delete history sync conversation with no messages from database")
|
||||
}
|
||||
} else if limit < 0 || bridgedCount < limit {
|
||||
bridgedCount++
|
||||
err = portal.CreateMatrixRoom(user, nil, nil, true, true)
|
||||
err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, true)
|
||||
if err != nil {
|
||||
user.zlog.Err(err).Msg("Failed to create Matrix room for backfill")
|
||||
}
|
||||
log.Err(err).Msg("Failed to create Matrix room for backfill")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) legacyBackfill(user *User) {
|
||||
func (portal *Portal) legacyBackfill(ctx context.Context, user *User) {
|
||||
defer portal.latestEventBackfillLock.Unlock()
|
||||
// This should only be called from CreateMatrixRoom which locks latestEventBackfillLock before creating the room.
|
||||
if portal.latestEventBackfillLock.TryLock() {
|
||||
panic("legacyBackfill() called without locking latestEventBackfillLock")
|
||||
}
|
||||
// TODO use portal.zlog instead of user.zlog
|
||||
log := user.zlog.With().
|
||||
Str("portal_jid", portal.Key.JID.String()).
|
||||
Str("action", "legacy backfill").
|
||||
Logger()
|
||||
conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, portal.Key)
|
||||
messages := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount)
|
||||
log := zerolog.Ctx(ctx).With().Str("action", "legacy backfill").Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, portal.Key)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get history sync conversation data for backfill")
|
||||
return
|
||||
}
|
||||
messages, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get history sync messages for backfill")
|
||||
return
|
||||
}
|
||||
log.Debug().Int("message_count", len(messages)).Msg("Got messages to backfill from database")
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
msgEvt, err := user.Client.ParseWebMessage(portal.Key.JID, messages[i])
|
||||
|
@ -194,18 +227,26 @@ func (portal *Portal) legacyBackfill(user *User) {
|
|||
Msg("Dropping historical message due to parse error")
|
||||
continue
|
||||
}
|
||||
portal.handleMessage(user, msgEvt, true)
|
||||
ctx := log.With().
|
||||
Str("message_id", msgEvt.Info.ID).
|
||||
Stringer("message_sender", msgEvt.Info.Sender).
|
||||
Logger().
|
||||
WithContext(ctx)
|
||||
portal.handleMessage(ctx, user, msgEvt, true)
|
||||
}
|
||||
if conv != nil {
|
||||
isUnread := conv.MarkedAsUnread || conv.UnreadCount > 0
|
||||
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
|
||||
shouldMarkAsRead := !isUnread || isTooOld
|
||||
if shouldMarkAsRead {
|
||||
user.markSelfReadFull(portal)
|
||||
user.markSelfReadFull(ctx, portal)
|
||||
}
|
||||
}
|
||||
log.Debug().Msg("Backfill complete, deleting leftover messages from database")
|
||||
user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String())
|
||||
log.Info().Msg("Backfill complete, deleting leftover messages from database")
|
||||
err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String())
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete history sync conversation from database after backfill")
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) dailyMediaRequestLoop() {
|
||||
|
@ -224,29 +265,49 @@ func (user *User) dailyMediaRequestLoop() {
|
|||
if requestStartTime.Before(now) {
|
||||
requestStartTime = requestStartTime.AddDate(0, 0, 1)
|
||||
}
|
||||
log := user.zlog.With().
|
||||
Str("action", "daily media request loop").
|
||||
Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
|
||||
// Wait to start the loop
|
||||
user.log.Infof("Waiting until %s to do media retry requests", requestStartTime)
|
||||
log.Info().Time("start_loop_at", requestStartTime).Msg("Waiting until start time to do media retry requests")
|
||||
time.Sleep(time.Until(requestStartTime))
|
||||
|
||||
for {
|
||||
mediaBackfillRequests := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(user.MXID)
|
||||
user.log.Infof("Sending %d media retry requests", len(mediaBackfillRequests))
|
||||
mediaBackfillRequests, err := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(ctx, user.MXID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get media retry requests")
|
||||
} else if len(mediaBackfillRequests) > 0 {
|
||||
log.Info().Int("media_request_count", len(mediaBackfillRequests)).Msg("Sending media retry requests")
|
||||
|
||||
// Send all of the media backfill requests for the user at once
|
||||
// Send all the media backfill requests for the user at once
|
||||
for _, req := range mediaBackfillRequests {
|
||||
portal := user.GetPortalByJID(req.PortalKey.JID)
|
||||
_, err := portal.requestMediaRetry(user, req.EventID, req.MediaKey)
|
||||
_, err = portal.requestMediaRetry(ctx, user, req.EventID, req.MediaKey)
|
||||
if err != nil {
|
||||
user.log.Warnf("Failed to send media retry request for %s / %s", req.PortalKey.String(), req.EventID)
|
||||
log.Err(err).
|
||||
Stringer("portal_key", req.PortalKey).
|
||||
Stringer("event_id", req.EventID).
|
||||
Msg("Failed to send media retry request")
|
||||
req.Status = database.MediaBackfillRequestStatusRequestFailed
|
||||
req.Error = err.Error()
|
||||
} else {
|
||||
user.log.Debugfln("Sent media retry request for %s / %s", req.PortalKey.String(), req.EventID)
|
||||
log.Debug().
|
||||
Stringer("portal_key", req.PortalKey).
|
||||
Stringer("event_id", req.EventID).
|
||||
Msg("Sent media retry request")
|
||||
req.Status = database.MediaBackfillRequestStatusRequested
|
||||
}
|
||||
req.MediaKey = nil
|
||||
req.Upsert()
|
||||
err = req.Upsert(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Stringer("portal_key", req.PortalKey).
|
||||
Stringer("event_id", req.EventID).
|
||||
Msg("Failed to save status of media retry request")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for 24 hours before making requests again
|
||||
|
@ -254,20 +315,29 @@ func (user *User) dailyMediaRequestLoop() {
|
|||
}
|
||||
}
|
||||
|
||||
func (user *User) backfillInChunks(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) {
|
||||
func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTask, conv *database.HistorySyncConversation, portal *Portal) {
|
||||
portal.backfillLock.Lock()
|
||||
defer portal.backfillLock.Unlock()
|
||||
log := zerolog.Ctx(ctx)
|
||||
|
||||
if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(portal.MXID, user.MXID) {
|
||||
portal.ensureUserInvited(user)
|
||||
if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(ctx, portal.MXID, user.MXID) {
|
||||
portal.ensureUserInvited(ctx, user)
|
||||
}
|
||||
|
||||
backfillState := user.bridge.DB.Backfill.GetBackfillState(user.MXID, &portal.Key)
|
||||
backfillState, err := user.bridge.DB.BackfillState.GetBackfillState(ctx, user.MXID, portal.Key)
|
||||
if backfillState == nil {
|
||||
backfillState = user.bridge.DB.Backfill.NewBackfillState(user.MXID, &portal.Key)
|
||||
backfillState = user.bridge.DB.BackfillState.NewBackfillState(user.MXID, portal.Key)
|
||||
}
|
||||
backfillState.SetProcessingBatch(true)
|
||||
defer backfillState.SetProcessingBatch(false)
|
||||
err = backfillState.SetProcessingBatch(ctx, true)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to mark batch as being processed")
|
||||
}
|
||||
defer func() {
|
||||
err = backfillState.SetProcessingBatch(ctx, false)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to mark batch as no longer being processed")
|
||||
}
|
||||
}()
|
||||
|
||||
var timeEnd *time.Time
|
||||
var forward, shouldMarkAsRead bool
|
||||
|
@ -275,17 +345,27 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
|
|||
if req.BackfillType == database.BackfillForward {
|
||||
// TODO this overrides the TimeStart set when enqueuing the backfill
|
||||
// maybe the enqueue should instead include the prev event ID
|
||||
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
|
||||
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get newest message in chat")
|
||||
return
|
||||
}
|
||||
start := lastMessage.Timestamp.Add(1 * time.Second)
|
||||
req.TimeStart = &start
|
||||
// Sending events at the end of the room (= latest events)
|
||||
forward = true
|
||||
} else {
|
||||
firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key)
|
||||
firstMessage, err := portal.bridge.DB.Message.GetFirstInChat(ctx, portal.Key)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get oldest message in chat")
|
||||
return
|
||||
}
|
||||
if firstMessage != nil {
|
||||
end := firstMessage.Timestamp.Add(-1 * time.Second)
|
||||
timeEnd = &end
|
||||
user.log.Debugfln("Limiting backfill to end at %v", end)
|
||||
log.Debug().
|
||||
Time("oldest_message_ts", firstMessage.Timestamp).
|
||||
Msg("Limiting backfill to messages older than oldest message")
|
||||
} else {
|
||||
// Portal is empty -> events are latest
|
||||
forward = true
|
||||
|
@ -303,45 +383,48 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
|
|||
isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour))
|
||||
shouldMarkAsRead = !isUnread || isTooOld
|
||||
}
|
||||
allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
|
||||
allMsgs, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents)
|
||||
|
||||
sendDisappearedNotice := false
|
||||
// If expired messages are on, and a notice has not been sent to this chat
|
||||
// about it having disappeared messages at the conversation timestamp, send
|
||||
// a notice indicating so.
|
||||
if len(allMsgs) == 0 && conv.EphemeralExpiration != nil && *conv.EphemeralExpiration > 0 {
|
||||
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
|
||||
lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get last message in chat to check if disappeared notice should be sent")
|
||||
}
|
||||
if lastMessage == nil || conv.LastMessageTimestamp.After(lastMessage.Timestamp) {
|
||||
sendDisappearedNotice = true
|
||||
}
|
||||
}
|
||||
|
||||
if !sendDisappearedNotice && len(allMsgs) == 0 {
|
||||
user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID)
|
||||
log.Debug().Msg("Not backfilling chat: no bridgeable messages found")
|
||||
return
|
||||
}
|
||||
|
||||
if len(portal.MXID) == 0 {
|
||||
user.log.Debugln("Creating portal for", portal.Key.JID, "as part of history sync handling")
|
||||
err := portal.CreateMatrixRoom(user, nil, nil, true, false)
|
||||
log.Debug().Msg("Creating portal for chat as part of history sync handling")
|
||||
err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, false)
|
||||
if err != nil {
|
||||
user.log.Errorfln("Failed to create room for %s during backfill: %v", portal.Key.JID, err)
|
||||
log.Err(err).Msg("Failed to create room for chat during backfill")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update the backfill status here after the room has been created.
|
||||
portal.updateBackfillStatus(backfillState)
|
||||
portal.updateBackfillStatus(ctx, backfillState)
|
||||
|
||||
if sendDisappearedNotice {
|
||||
user.log.Debugfln("Sending notice to %s that there are disappeared messages ending at %v", portal.Key.JID, conv.LastMessageTimestamp)
|
||||
resp, err := portal.sendMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
|
||||
log.Debug().Time("last_message_time", conv.LastMessageTimestamp).
|
||||
Msg("Sending notice that there are disappeared messages in the chat")
|
||||
resp, err := portal.sendMessage(ctx, portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
|
||||
MsgType: event.MsgNotice,
|
||||
Body: portal.formatDisappearingMessageNotice(),
|
||||
}, nil, conv.LastMessageTimestamp.UnixMilli())
|
||||
|
||||
if err != nil {
|
||||
portal.log.Errorln("Error sending disappearing messages notice event")
|
||||
log.Err(err).Msg("Failed to send disappeared messages notice event")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -353,12 +436,18 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
|
|||
msg.SenderMXID = portal.MainIntent().UserID
|
||||
msg.Sent = true
|
||||
msg.Type = database.MsgFake
|
||||
msg.Insert(nil)
|
||||
user.markSelfReadFull(portal)
|
||||
err = msg.Insert(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save fake message entry for disappearing message timer in backfill")
|
||||
}
|
||||
user.markSelfReadFull(ctx, portal)
|
||||
return
|
||||
}
|
||||
|
||||
user.log.Infofln("Backfilling %d messages in %s, %d messages at a time (queue ID: %d)", len(allMsgs), portal.Key.JID, req.MaxBatchEvents, req.QueueID)
|
||||
log.Info().
|
||||
Int("message_count", len(allMsgs)).
|
||||
Int("max_batch_events", req.MaxBatchEvents).
|
||||
Msg("Backfilling messages")
|
||||
toBackfill := allMsgs[0:]
|
||||
for len(toBackfill) > 0 {
|
||||
var msgs []*waProto.WebMessageInfo
|
||||
|
@ -372,14 +461,14 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
|
|||
|
||||
if len(msgs) > 0 {
|
||||
time.Sleep(time.Duration(req.BatchDelay) * time.Second)
|
||||
user.log.Debugfln("Backfilling %d messages in %s (queue ID: %d)", len(msgs), portal.Key.JID, req.QueueID)
|
||||
portal.backfill(user, msgs, forward, shouldMarkAsRead)
|
||||
log.Debug().Int("batch_message_count", len(msgs)).Msg("Backfilling message batch")
|
||||
portal.backfill(ctx, user, msgs, forward, shouldMarkAsRead)
|
||||
}
|
||||
}
|
||||
user.log.Debugfln("Finished backfilling %d messages in %s (queue ID: %d)", len(allMsgs), portal.Key.JID, req.QueueID)
|
||||
err := user.bridge.DB.HistorySync.DeleteMessages(user.MXID, conv.ConversationID, allMsgs)
|
||||
log.Debug().Int("message_count", len(allMsgs)).Msg("Finished backfilling messages in queue entry")
|
||||
err = user.bridge.DB.HistorySync.DeleteMessages(ctx, user.MXID, conv.ConversationID, allMsgs)
|
||||
if err != nil {
|
||||
user.log.Warnfln("Failed to delete %d history sync messages after backfilling (queue ID: %d): %v", len(allMsgs), req.QueueID, err)
|
||||
log.Err(err).Msg("Failed to delete history sync messages after backfilling")
|
||||
}
|
||||
|
||||
if req.TimeStart == nil {
|
||||
|
@ -399,8 +488,11 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor
|
|||
// beginning of time.
|
||||
backfillState.FirstExpectedTimestamp = 0
|
||||
}
|
||||
backfillState.Upsert()
|
||||
portal.updateBackfillStatus(backfillState)
|
||||
err = backfillState.Upsert(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to mark backfill state as completed in database")
|
||||
}
|
||||
portal.updateBackfillStatus(ctx, backfillState)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -408,13 +500,13 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
|
|||
if evt == nil || evt.SyncType == nil {
|
||||
return
|
||||
}
|
||||
log := user.bridge.ZLog.With().
|
||||
log := user.zlog.With().
|
||||
Str("method", "User.storeHistorySync").
|
||||
Str("user_id", user.MXID.String()).
|
||||
Str("sync_type", evt.GetSyncType().String()).
|
||||
Uint32("chunk_order", evt.GetChunkOrder()).
|
||||
Uint32("progress", evt.GetProgress()).
|
||||
Logger()
|
||||
ctx := log.WithContext(context.TODO())
|
||||
if evt.GetGlobalSettings() != nil {
|
||||
log.Debug().Interface("global_settings", evt.GetGlobalSettings()).Msg("Got global settings in history sync")
|
||||
}
|
||||
|
@ -466,7 +558,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
|
|||
historySyncConversation := user.bridge.DB.HistorySync.NewConversationWithValues(
|
||||
user.MXID,
|
||||
conv.GetId(),
|
||||
&portal.Key,
|
||||
portal.Key,
|
||||
getConversationTimestamp(conv),
|
||||
conv.GetMuteEndTime(),
|
||||
conv.GetArchived(),
|
||||
|
@ -476,7 +568,10 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
|
|||
conv.EphemeralExpiration,
|
||||
conv.GetMarkedAsUnread(),
|
||||
conv.GetUnreadCount())
|
||||
historySyncConversation.Upsert()
|
||||
err := historySyncConversation.Upsert(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to insert history sync conversation into database")
|
||||
}
|
||||
}
|
||||
|
||||
var minTime, maxTime time.Time
|
||||
|
@ -521,7 +616,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) {
|
|||
Msg("Failed to save historical message")
|
||||
continue
|
||||
}
|
||||
err = message.Insert()
|
||||
err = message.Insert(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Int("msg_index", i).
|
||||
|
@ -570,15 +665,20 @@ func getConversationTimestamp(conv *waProto.Conversation) uint64 {
|
|||
return convTs
|
||||
}
|
||||
|
||||
func (user *User) EnqueueImmediateBackfills(portals []*Portal) {
|
||||
func (user *User) EnqueueImmediateBackfills(ctx context.Context, portals []*Portal) {
|
||||
for priority, portal := range portals {
|
||||
maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents
|
||||
initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, maxMessages, maxMessages, 0)
|
||||
initialBackfill.Insert()
|
||||
initialBackfill := user.bridge.DB.BackfillQueue.NewWithValues(user.MXID, database.BackfillImmediate, priority, portal.Key, nil, maxMessages, maxMessages, 0)
|
||||
err := initialBackfill.Insert(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("portal_key", portal.Key).
|
||||
Msg("Failed to insert immediate backfill into database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
|
||||
func (user *User) EnqueueDeferredBackfills(ctx context.Context, portals []*Portal) {
|
||||
numPortals := len(portals)
|
||||
for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Deferred {
|
||||
for portalIdx, portal := range portals {
|
||||
|
@ -587,22 +687,36 @@ func (user *User) EnqueueDeferredBackfills(portals []*Portal) {
|
|||
startDaysAgo := time.Now().AddDate(0, 0, -backfillStage.StartDaysAgo)
|
||||
startDate = &startDaysAgo
|
||||
}
|
||||
backfillMessages := user.bridge.DB.Backfill.NewWithValues(
|
||||
user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
|
||||
backfillMessages.Insert()
|
||||
backfillMessages := user.bridge.DB.BackfillQueue.NewWithValues(
|
||||
user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay)
|
||||
err := backfillMessages.Insert(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("portal_key", portal.Key).
|
||||
Msg("Failed to insert deferred backfill into database")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) EnqueueForwardBackfills(portals []*Portal) {
|
||||
func (user *User) EnqueueForwardBackfills(ctx context.Context, portals []*Portal) {
|
||||
for priority, portal := range portals {
|
||||
lastMsg := user.bridge.DB.Message.GetLastInChat(portal.Key)
|
||||
if lastMsg == nil {
|
||||
lastMsg, err := user.bridge.DB.Message.GetLastInChat(ctx, portal.Key)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("portal_key", portal.Key).
|
||||
Msg("Failed to get last message in chat to enqueue forward backfill")
|
||||
} else if lastMsg == nil {
|
||||
continue
|
||||
}
|
||||
backfill := user.bridge.DB.Backfill.NewWithValues(
|
||||
user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, -1, -1, 0)
|
||||
backfill.Insert()
|
||||
backfill := user.bridge.DB.BackfillQueue.NewWithValues(
|
||||
user.MXID, database.BackfillForward, priority, portal.Key, &lastMsg.Timestamp, -1, -1, 0)
|
||||
err = backfill.Insert(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("portal_key", portal.Key).
|
||||
Msg("Failed to insert forward backfill into database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -619,12 +733,11 @@ func (portal *Portal) deterministicEventID(sender types.JID, messageID types.Mes
|
|||
}
|
||||
|
||||
var (
|
||||
PortalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType}
|
||||
|
||||
BackfillStatusEvent = event.Type{Type: "com.beeper.backfill_status", Class: event.StateEventType}
|
||||
)
|
||||
|
||||
func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) *mautrix.RespBeeperBatchSend {
|
||||
func (portal *Portal) backfill(ctx context.Context, source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
var req mautrix.ReqBeeperBatchSend
|
||||
var infos []*wrappedInfo
|
||||
|
||||
|
@ -633,7 +746,10 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
|
|||
req.MarkReadBy = source.MXID
|
||||
}
|
||||
|
||||
portal.log.Infofln("Processing history sync with %d messages (forward: %t)", len(messages), isForward)
|
||||
log.Info().
|
||||
Bool("forward", isForward).
|
||||
Int("message_count", len(messages)).
|
||||
Msg("Processing history sync message batch")
|
||||
// The messages are ordered newest to oldest, so iterate them in reverse order.
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
webMsg := messages[i]
|
||||
|
@ -641,11 +757,16 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
|
|||
if err != nil {
|
||||
continue
|
||||
}
|
||||
log := log.With().
|
||||
Str("message_id", msgEvt.Info.ID).
|
||||
Stringer("message_sender", msgEvt.Info.Sender).
|
||||
Logger()
|
||||
ctx := log.WithContext(ctx)
|
||||
|
||||
msgType := getMessageType(msgEvt.Message)
|
||||
if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" {
|
||||
if msgType != "ignore" {
|
||||
portal.log.Debugfln("Skipping message %s with unknown type in backfill", msgEvt.Info.ID)
|
||||
log.Debug().Msg("Skipping message with unknown type in backfill")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
@ -654,85 +775,83 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo,
|
|||
if !existingContact.Found || existingContact.PushName == "" {
|
||||
changed, _, err := source.Client.Store.Contacts.PutPushName(msgEvt.Info.Sender, webMsg.GetPushName())
|
||||
if err != nil {
|
||||
source.log.Errorfln("Failed to save push name of %s from historical message in device store: %v", msgEvt.Info.Sender, err)
|
||||
log.Err(err).Msg("Failed to save push name from historical message to device store")
|
||||
} else if changed {
|
||||
source.log.Debugfln("Got push name %s for %s from historical message", webMsg.GetPushName(), msgEvt.Info.Sender)
|
||||
log.Debug().Str("push_name", webMsg.GetPushName()).Msg("Got push name from historical message")
|
||||
}
|
||||
}
|
||||
}
|
||||
puppet := portal.getMessagePuppet(source, &msgEvt.Info)
|
||||
puppet := portal.getMessagePuppet(ctx, source, &msgEvt.Info)
|
||||
if puppet == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
converted := portal.convertMessage(puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true)
|
||||
converted := portal.convertMessage(ctx, puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true)
|
||||
if converted == nil {
|
||||
portal.log.Debugfln("Skipping unsupported message %s in backfill", msgEvt.Info.ID)
|
||||
log.Debug().Msg("Skipping unsupported message in backfill")
|
||||
continue
|
||||
}
|
||||
if converted.ReplyTo != nil {
|
||||
portal.SetReply(msgEvt.Info.ID, converted.Content, converted.ReplyTo, true)
|
||||
portal.SetReply(ctx, converted.Content, converted.ReplyTo, true)
|
||||
}
|
||||
err = portal.appendBatchEvents(source, converted, &msgEvt.Info, webMsg, &req.Events, &infos)
|
||||
err = portal.appendBatchEvents(ctx, source, converted, &msgEvt.Info, webMsg, &req.Events, &infos)
|
||||
if err != nil {
|
||||
portal.log.Errorfln("Error handling message %s during backfill: %v", msgEvt.Info.ID, err)
|
||||
log.Err(err).Msg("Failed to handle message in backfill")
|
||||
}
|
||||
}
|
||||
portal.log.Infofln("Made %d Matrix events from messages in batch", len(req.Events))
|
||||
log.Info().Int("event_count", len(req.Events)).Msg("Made Matrix events from messages in batch")
|
||||
|
||||
if len(req.Events) == 0 {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &req)
|
||||
resp, err := portal.MainIntent().BeeperBatchSend(ctx, portal.MXID, &req)
|
||||
if err != nil {
|
||||
portal.log.Errorln("Error batch sending messages:", err)
|
||||
return nil
|
||||
} else {
|
||||
txn, err := portal.bridge.DB.Begin()
|
||||
if err != nil {
|
||||
portal.log.Errorln("Failed to start transaction to save batch messages:", err)
|
||||
return nil
|
||||
log.Err(err).Msg("Failed to send batch of messages")
|
||||
return
|
||||
}
|
||||
|
||||
portal.finishBatch(txn, resp.EventIDs, infos)
|
||||
|
||||
err = txn.Commit()
|
||||
err = portal.bridge.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
return portal.finishBatch(ctx, resp.EventIDs, infos)
|
||||
})
|
||||
if err != nil {
|
||||
portal.log.Errorln("Failed to commit transaction to save batch messages:", err)
|
||||
return nil
|
||||
log.Err(err).Msg("Failed to save message batch to database")
|
||||
return
|
||||
}
|
||||
log.Info().Msg("Successfully sent backfill batch")
|
||||
if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia {
|
||||
go portal.requestMediaRetries(source, resp.EventIDs, infos)
|
||||
}
|
||||
return resp
|
||||
go portal.requestMediaRetries(context.TODO(), source, resp.EventIDs, infos)
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) requestMediaRetries(source *User, eventIDs []id.EventID, infos []*wrappedInfo) {
|
||||
func (portal *Portal) requestMediaRetries(ctx context.Context, source *User, eventIDs []id.EventID, infos []*wrappedInfo) {
|
||||
for i, info := range infos {
|
||||
if info != nil && info.Error == database.MsgErrMediaNotFound && info.MediaKey != nil {
|
||||
switch portal.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod {
|
||||
case config.MediaRequestMethodImmediate:
|
||||
err := source.Client.SendMediaRetryReceipt(info.MessageInfo, info.MediaKey)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to send post-backfill media retry request for %s: %v", info.ID, err)
|
||||
portal.zlog.Err(err).Str("message_id", info.ID).Msg("Failed to send post-backfill media retry request")
|
||||
} else {
|
||||
portal.log.Debugfln("Sent post-backfill media retry request for %s", info.ID)
|
||||
portal.zlog.Debug().Str("message_id", info.ID).Msg("Sent post-backfill media retry request")
|
||||
}
|
||||
case config.MediaRequestMethodLocalTime:
|
||||
req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, &portal.Key, eventIDs[i], info.MediaKey)
|
||||
req.Upsert()
|
||||
req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, portal.Key, eventIDs[i], info.MediaKey)
|
||||
err := req.Upsert(ctx)
|
||||
if err != nil {
|
||||
portal.zlog.Err(err).
|
||||
Stringer("event_id", eventIDs[i]).
|
||||
Msg("Failed to upsert media backfill request")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
|
||||
func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error {
|
||||
if portal.bridge.Config.Bridge.CaptionInMessage {
|
||||
converted.MergeCaption()
|
||||
}
|
||||
mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
|
||||
mainEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Content, converted.Extra, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -750,7 +869,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
|
|||
ExpiresIn: converted.ExpiresIn,
|
||||
}
|
||||
if converted.Caption != nil {
|
||||
captionEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
|
||||
captionEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Caption, nil, "caption")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -762,7 +881,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
|
|||
}
|
||||
if converted.MultiEvent != nil {
|
||||
for i, subEvtContent := range converted.MultiEvent {
|
||||
subEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
|
||||
subEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -771,7 +890,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
|
|||
}
|
||||
}
|
||||
for _, reaction := range raw.GetReactions() {
|
||||
reactionEvent, reactionInfo := portal.wrapBatchReaction(source, reaction, mainEvt.ID, info.Timestamp)
|
||||
reactionEvent, reactionInfo := portal.wrapBatchReaction(ctx, source, reaction, mainEvt.ID, info.Timestamp)
|
||||
if reactionEvent != nil {
|
||||
*eventsArray = append(*eventsArray, reactionEvent)
|
||||
*infoArray = append(*infoArray, &wrappedInfo{
|
||||
|
@ -785,7 +904,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag
|
|||
return nil
|
||||
}
|
||||
|
||||
func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) {
|
||||
func (portal *Portal) wrapBatchReaction(ctx context.Context, source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) {
|
||||
var senderJID types.JID
|
||||
if reaction.GetKey().GetFromMe() {
|
||||
senderJID = source.JID.ToNonAD()
|
||||
|
@ -807,7 +926,7 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction
|
|||
ID: reaction.GetKey().GetId(),
|
||||
Timestamp: mainEventTS,
|
||||
}
|
||||
puppet := portal.getMessagePuppet(source, reactionInfo)
|
||||
puppet := portal.getMessagePuppet(ctx, source, reactionInfo)
|
||||
if puppet == nil {
|
||||
return
|
||||
}
|
||||
|
@ -834,12 +953,12 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction
|
|||
return
|
||||
}
|
||||
|
||||
func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
|
||||
func (portal *Portal) wrapBatchEvent(ctx context.Context, info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) {
|
||||
wrappedContent := event.Content{
|
||||
Parsed: content,
|
||||
Raw: extraContent,
|
||||
}
|
||||
newEventType, err := portal.encrypt(intent, &wrappedContent, eventType)
|
||||
newEventType, err := portal.encrypt(ctx, intent, &wrappedContent, eventType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -853,37 +972,37 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID, infos []*wrappedInfo) {
|
||||
func (portal *Portal) finishBatch(ctx context.Context, eventIDs []id.EventID, infos []*wrappedInfo) error {
|
||||
for i, info := range infos {
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventID := eventIDs[i]
|
||||
portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error)
|
||||
portal.markHandled(ctx, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error)
|
||||
if info.Type == database.MsgReaction {
|
||||
portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
|
||||
portal.upsertReaction(ctx, nil, info.ReactionTarget, info.Sender, eventID, info.ID)
|
||||
}
|
||||
|
||||
if info.ExpiresIn > 0 {
|
||||
portal.MarkDisappearing(txn, eventID, info.ExpiresIn, info.ExpirationStart)
|
||||
portal.MarkDisappearing(ctx, eventID, info.ExpiresIn, info.ExpirationStart)
|
||||
}
|
||||
}
|
||||
portal.log.Infofln("Successfully sent %d events", len(eventIDs))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (portal *Portal) updateBackfillStatus(backfillState *database.BackfillState) {
|
||||
func (portal *Portal) updateBackfillStatus(ctx context.Context, backfillState *database.BackfillState) {
|
||||
backfillStatus := "backfilling"
|
||||
if backfillState.BackfillComplete {
|
||||
backfillStatus = "complete"
|
||||
}
|
||||
|
||||
_, err := portal.bridge.Bot.SendStateEvent(portal.MXID, BackfillStatusEvent, "", map[string]interface{}{
|
||||
_, err := portal.bridge.Bot.SendStateEvent(ctx, portal.MXID, BackfillStatusEvent, "", map[string]interface{}{
|
||||
"status": backfillStatus,
|
||||
"first_timestamp": backfillState.FirstExpectedTimestamp * 1000,
|
||||
})
|
||||
if err != nil {
|
||||
portal.log.Errorln("Error sending backfill status event:", err)
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill status event to room")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
62
main.go
62
main.go
|
@ -17,6 +17,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -26,15 +27,18 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
waLog "go.mau.fi/whatsmeow/util/log"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"go.mau.fi/util/configupgrade"
|
||||
"go.mau.fi/whatsmeow"
|
||||
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||||
"go.mau.fi/whatsmeow/store"
|
||||
"go.mau.fi/whatsmeow/store/sqlstore"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
"go.mau.fi/util/configupgrade"
|
||||
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/bridge/commands"
|
||||
"maunium.net/go/mautrix/bridge/status"
|
||||
|
@ -91,7 +95,7 @@ func (br *WABridge) Init() {
|
|||
br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage)
|
||||
br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage)
|
||||
|
||||
Analytics.log = br.Log.Sub("Analytics")
|
||||
Analytics.log = br.ZLog.With().Str("component", "analytics").Logger()
|
||||
Analytics.url = (&url.URL{
|
||||
Scheme: "https",
|
||||
Host: br.Config.Analytics.Host,
|
||||
|
@ -100,23 +104,20 @@ func (br *WABridge) Init() {
|
|||
Analytics.key = br.Config.Analytics.Token
|
||||
Analytics.userID = br.Config.Analytics.UserID
|
||||
if Analytics.IsEnabled() {
|
||||
Analytics.log.Infoln("Analytics metrics are enabled")
|
||||
if Analytics.userID != "" {
|
||||
Analytics.log.Infoln("Overriding analytics user_id with %v", Analytics.userID)
|
||||
}
|
||||
Analytics.log.Info().Str("override_user_id", Analytics.userID).Msg("Analytics metrics are enabled")
|
||||
}
|
||||
|
||||
br.DB = database.New(br.Bridge.DB, br.Log.Sub("Database"))
|
||||
br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), &waLogger{br.Log.Sub("Database").Sub("WhatsApp")})
|
||||
br.DB = database.New(br.Bridge.DB)
|
||||
br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), waLog.Zerolog(br.ZLog.With().Str("db_section", "whatsmeow").Logger()))
|
||||
br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
|
||||
|
||||
ss := br.Config.Bridge.Provisioning.SharedSecret
|
||||
if len(ss) > 0 && ss != "disable" {
|
||||
br.Provisioning = &ProvisioningAPI{bridge: br}
|
||||
br.Provisioning = &ProvisioningAPI{bridge: br, log: br.ZLog.With().Str("component", "provisioning").Logger()}
|
||||
}
|
||||
|
||||
br.Formatter = NewFormatter(br)
|
||||
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
|
||||
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.ZLog.With().Str("component", "metrics").Logger(), br.DB)
|
||||
br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent
|
||||
|
||||
store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
|
||||
|
@ -148,11 +149,10 @@ func (br *WABridge) Init() {
|
|||
func (br *WABridge) Start() {
|
||||
err := br.WAContainer.Upgrade()
|
||||
if err != nil {
|
||||
br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err)
|
||||
br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to upgrade whatsmeow database")
|
||||
os.Exit(15)
|
||||
}
|
||||
if br.Provisioning != nil {
|
||||
br.Log.Debugln("Initializing provisioning API")
|
||||
br.Provisioning.Init()
|
||||
}
|
||||
go br.CheckWhatsAppUpdate()
|
||||
|
@ -166,30 +166,40 @@ func (br *WABridge) Start() {
|
|||
}
|
||||
|
||||
func (br *WABridge) CheckWhatsAppUpdate() {
|
||||
br.Log.Debugfln("Checking for WhatsApp web update")
|
||||
br.ZLog.Debug().Msg("Checking for WhatsApp web update")
|
||||
resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
|
||||
if err != nil {
|
||||
br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
|
||||
br.ZLog.Warn().Err(err).Msg("Failed to check for WhatsApp web update")
|
||||
return
|
||||
}
|
||||
if store.GetWAVersion() == resp.ParsedVersion {
|
||||
br.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
|
||||
br.ZLog.Debug().Msg("Bridge is using latest WhatsApp web protocol")
|
||||
} else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
|
||||
if resp.IsBelowHard || resp.IsBroken {
|
||||
br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
br.ZLog.Warn().
|
||||
Stringer("latest_version", resp.ParsedVersion).
|
||||
Stringer("current_version", store.GetWAVersion()).
|
||||
Msg("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore")
|
||||
} else if resp.IsBelowSoft {
|
||||
br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
br.ZLog.Info().
|
||||
Stringer("latest_version", resp.ParsedVersion).
|
||||
Stringer("current_version", store.GetWAVersion()).
|
||||
Msg("Bridge is using outdated WhatsApp web protocol")
|
||||
} else {
|
||||
br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
|
||||
br.ZLog.Debug().
|
||||
Stringer("latest_version", resp.ParsedVersion).
|
||||
Stringer("current_version", store.GetWAVersion()).
|
||||
Msg("Bridge is using outdated WhatsApp web protocol")
|
||||
}
|
||||
} else {
|
||||
br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
|
||||
br.ZLog.Debug().Msg("Bridge is using newer than latest WhatsApp web protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func (br *WABridge) Loop() {
|
||||
ctx := br.ZLog.With().Str("action", "background loop").Logger().WithContext(context.TODO())
|
||||
for {
|
||||
br.SleepAndDeleteUpcoming()
|
||||
br.SleepAndDeleteUpcoming(ctx)
|
||||
time.Sleep(1 * time.Hour)
|
||||
br.WarnUsersAboutDisconnection()
|
||||
}
|
||||
|
@ -199,14 +209,14 @@ func (br *WABridge) WarnUsersAboutDisconnection() {
|
|||
br.usersLock.Lock()
|
||||
for _, user := range br.usersByUsername {
|
||||
if user.IsConnected() && !user.PhoneRecentlySeen(true) {
|
||||
go user.sendPhoneOfflineWarning()
|
||||
go user.sendPhoneOfflineWarning(context.TODO())
|
||||
}
|
||||
}
|
||||
br.usersLock.Unlock()
|
||||
}
|
||||
|
||||
func (br *WABridge) StartUsers() {
|
||||
br.Log.Debugln("Starting users")
|
||||
br.ZLog.Debug().Msg("Starting users")
|
||||
foundAnySessions := false
|
||||
for _, user := range br.GetAllUsers() {
|
||||
if !user.JID.IsEmpty() {
|
||||
|
@ -217,13 +227,13 @@ func (br *WABridge) StartUsers() {
|
|||
if !foundAnySessions {
|
||||
br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}.Fill(nil))
|
||||
}
|
||||
br.Log.Debugln("Starting custom puppets")
|
||||
br.ZLog.Debug().Msg("Starting custom puppets")
|
||||
for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
|
||||
go func(puppet *Puppet) {
|
||||
puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
|
||||
puppet.zlog.Debug().Stringer("custom_mxid", puppet.CustomMXID).Msg("Starting double puppet")
|
||||
err := puppet.StartCustomMXID(true)
|
||||
if err != nil {
|
||||
puppet.log.Errorln("Failed to start custom puppet:", err)
|
||||
puppet.zlog.Err(err).Stringer("custom_mxid", puppet.CustomMXID).Msg("Failed to start double puppet")
|
||||
}
|
||||
}(loopuppet)
|
||||
}
|
||||
|
@ -235,7 +245,7 @@ func (br *WABridge) Stop() {
|
|||
if user.Client == nil {
|
||||
continue
|
||||
}
|
||||
br.Log.Debugln("Disconnecting", user.MXID)
|
||||
user.zlog.Debug().Msg("Disconnecting user")
|
||||
user.Client.Disconnect()
|
||||
close(user.historySyncs)
|
||||
}
|
||||
|
|
78
matrix.go
78
matrix.go
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -17,8 +17,10 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
|
@ -35,77 +37,89 @@ func (br *WABridge) CreatePrivatePortal(roomID id.RoomID, brInviter bridge.User,
|
|||
puppet := brGhost.(*Puppet)
|
||||
key := database.NewPortalKey(puppet.JID, inviter.JID)
|
||||
portal := br.GetPortalByJID(key)
|
||||
log := br.ZLog.With().
|
||||
Str("action", "create private portal").
|
||||
Stringer("target_room_id", roomID).
|
||||
Stringer("inviter_mxid", inviter.MXID).
|
||||
Stringer("invitee_jid", puppet.JID).
|
||||
Logger()
|
||||
ctx := log.WithContext(context.TODO())
|
||||
|
||||
if len(portal.MXID) == 0 {
|
||||
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal)
|
||||
br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
|
||||
return
|
||||
}
|
||||
|
||||
ok := portal.ensureUserInvited(inviter)
|
||||
ok := portal.ensureUserInvited(ctx, inviter)
|
||||
if !ok {
|
||||
br.Log.Warnfln("Failed to invite %s to existing private chat portal %s with %s. Redirecting portal to new room...", inviter.MXID, portal.MXID, puppet.JID)
|
||||
br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal)
|
||||
log.Warn().Msg("Failed to invite user to existing private chat portal. Redirecting portal to new room...")
|
||||
br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal)
|
||||
return
|
||||
}
|
||||
intent := puppet.DefaultIntent()
|
||||
errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%[1]s](https://matrix.to/#/%[1]s)", portal.MXID)
|
||||
errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Config.Homeserver.Domain).MatrixToURL())
|
||||
errorContent := format.RenderMarkdown(errorMessage, true, false)
|
||||
_, _ = intent.SendMessageEvent(roomID, event.EventMessage, errorContent)
|
||||
br.Log.Debugfln("Leaving private chat room %s as %s after accepting invite from %s as we already have chat with the user", roomID, puppet.MXID, inviter.MXID)
|
||||
_, _ = intent.LeaveRoom(roomID)
|
||||
_, _ = intent.SendMessageEvent(ctx, roomID, event.EventMessage, errorContent)
|
||||
log.Debug().Msg("Leaving private chat room from invite as we already have chat with the user")
|
||||
_, _ = intent.LeaveRoom(ctx, roomID)
|
||||
}
|
||||
|
||||
func (br *WABridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) {
|
||||
func (br *WABridge) createPrivatePortalFromInvite(ctx context.Context, roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
// TODO check if room is already encrypted
|
||||
var existingEncryption event.EncryptionEventContent
|
||||
var encryptionEnabled bool
|
||||
err := portal.MainIntent().StateEvent(roomID, event.StateEncryption, "", &existingEncryption)
|
||||
err := portal.MainIntent().StateEvent(ctx, roomID, event.StateEncryption, "", &existingEncryption)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to check if encryption is enabled in private chat room %s", roomID)
|
||||
log.Err(err).Msg("Failed to check if encryption is enabled")
|
||||
} else {
|
||||
encryptionEnabled = existingEncryption.Algorithm == id.AlgorithmMegolmV1
|
||||
}
|
||||
portal.MXID = roomID
|
||||
portal.updateLogger()
|
||||
portal.Topic = PrivateChatTopic
|
||||
portal.Name = puppet.Displayname
|
||||
portal.AvatarURL = puppet.AvatarURL
|
||||
portal.Avatar = puppet.Avatar
|
||||
portal.log.Infofln("Created private chat portal in %s after invite from %s", roomID, inviter.MXID)
|
||||
log.Info().Msg("Created private chat portal from invite")
|
||||
intent := puppet.DefaultIntent()
|
||||
|
||||
if br.Config.Bridge.Encryption.Default || encryptionEnabled {
|
||||
_, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID})
|
||||
_, err = intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID})
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to invite bridge bot to enable e2be:", err)
|
||||
log.Err(err).Msg("Failed to invite bridge bot to enable e2be")
|
||||
}
|
||||
err = br.Bot.EnsureJoined(roomID)
|
||||
err = br.Bot.EnsureJoined(ctx, roomID)
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to join as bridge bot to enable e2be:", err)
|
||||
log.Err(err).Msg("Failed to join as bridge bot to enable e2be")
|
||||
}
|
||||
if !encryptionEnabled {
|
||||
_, err = intent.SendStateEvent(roomID, event.StateEncryption, "", portal.GetEncryptionEventContent())
|
||||
_, err = intent.SendStateEvent(ctx, roomID, event.StateEncryption, "", portal.GetEncryptionEventContent())
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to enable e2be:", err)
|
||||
log.Err(err).Msg("Failed to enable e2be")
|
||||
}
|
||||
}
|
||||
br.AS.StateStore.SetMembership(roomID, inviter.MXID, event.MembershipJoin)
|
||||
br.AS.StateStore.SetMembership(roomID, puppet.MXID, event.MembershipJoin)
|
||||
br.AS.StateStore.SetMembership(roomID, br.Bot.UserID, event.MembershipJoin)
|
||||
br.AS.StateStore.SetMembership(ctx, roomID, inviter.MXID, event.MembershipJoin)
|
||||
br.AS.StateStore.SetMembership(ctx, roomID, puppet.MXID, event.MembershipJoin)
|
||||
br.AS.StateStore.SetMembership(ctx, roomID, br.Bot.UserID, event.MembershipJoin)
|
||||
portal.Encrypted = true
|
||||
}
|
||||
_, _ = portal.MainIntent().SetRoomTopic(portal.MXID, portal.Topic)
|
||||
_, _ = portal.MainIntent().SetRoomTopic(ctx, portal.MXID, portal.Topic)
|
||||
if portal.shouldSetDMRoomMetadata() {
|
||||
_, err = portal.MainIntent().SetRoomName(portal.MXID, portal.Name)
|
||||
_, err = portal.MainIntent().SetRoomName(ctx, portal.MXID, portal.Name)
|
||||
portal.NameSet = err == nil
|
||||
_, err = portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL)
|
||||
_, err = portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, portal.AvatarURL)
|
||||
portal.AvatarSet = err == nil
|
||||
}
|
||||
portal.Update(nil)
|
||||
portal.UpdateBridgeInfo()
|
||||
_, _ = intent.SendNotice(roomID, "Private chat portal created")
|
||||
err = portal.Update(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save portal to database after creating from invite")
|
||||
}
|
||||
portal.UpdateBridgeInfo(ctx)
|
||||
_, _ = intent.SendNotice(ctx, roomID, "Private chat portal created")
|
||||
}
|
||||
|
||||
func (br *WABridge) HandlePresence(evt *event.Event) {
|
||||
func (br *WABridge) HandlePresence(ctx context.Context, evt *event.Event) {
|
||||
user := br.GetUserByMXIDIfExists(evt.Sender)
|
||||
if user == nil || !user.IsLoggedIn() {
|
||||
return
|
||||
|
@ -119,15 +133,15 @@ func (br *WABridge) HandlePresence(evt *event.Event) {
|
|||
presence := types.PresenceAvailable
|
||||
if evt.Content.AsPresence().Presence != event.PresenceOnline {
|
||||
presence = types.PresenceUnavailable
|
||||
user.log.Debugln("Marking offline")
|
||||
user.zlog.Debug().Msg("Marking offline")
|
||||
} else {
|
||||
user.log.Debugln("Marking online")
|
||||
user.zlog.Debug().Msg("Marking online")
|
||||
}
|
||||
user.lastPresence = presence
|
||||
if user.Client.Store.PushName != "" {
|
||||
err := user.Client.SendPresence(presence)
|
||||
if err != nil {
|
||||
user.log.Warnln("Failed to set presence:", err)
|
||||
user.zlog.Err(err).Msg("Failed to set presence")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -23,7 +23,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"go.mau.fi/whatsmeow"
|
||||
|
||||
|
@ -123,7 +123,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev
|
|||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType string, confirmed bool, editID id.EventID) id.EventID {
|
||||
func (portal *Portal) sendErrorMessage(ctx context.Context, evt *event.Event, err error, confirmed bool, editID id.EventID) id.EventID {
|
||||
if !portal.bridge.Config.Bridge.MessageErrorNotices {
|
||||
return ""
|
||||
}
|
||||
|
@ -131,6 +131,21 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri
|
|||
if confirmed {
|
||||
certainty = "was not"
|
||||
}
|
||||
var msgType string
|
||||
switch evt.Type {
|
||||
case event.EventMessage:
|
||||
msgType = "message"
|
||||
case event.EventReaction:
|
||||
msgType = "reaction"
|
||||
case event.EventRedaction:
|
||||
msgType = "redaction"
|
||||
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
|
||||
msgType = "poll response"
|
||||
case TypeMSC3381PollStart:
|
||||
msgType = "poll start"
|
||||
default:
|
||||
msgType = "unknown event"
|
||||
}
|
||||
msg := fmt.Sprintf("\u26a0 Your %s %s bridged: %v", msgType, certainty, err)
|
||||
if errors.Is(err, errMessageTakingLong) {
|
||||
msg = fmt.Sprintf("\u26a0 Bridging your %s is taking longer than usual", msgType)
|
||||
|
@ -144,15 +159,15 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri
|
|||
} else {
|
||||
content.SetReply(evt)
|
||||
}
|
||||
resp, err := portal.sendMainIntentMessage(content)
|
||||
resp, err := portal.sendMainIntentMessage(ctx, content)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to send bridging error message:", err)
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to send bridging error message")
|
||||
return ""
|
||||
}
|
||||
return resp.EventID
|
||||
}
|
||||
|
||||
func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) {
|
||||
func (portal *Portal) sendStatusEvent(ctx context.Context, evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) {
|
||||
if !portal.bridge.Config.Bridge.MessageStatusEvents {
|
||||
return
|
||||
}
|
||||
|
@ -179,75 +194,56 @@ func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, de
|
|||
content.Reason, content.Status, _, _, content.Message = errorToStatusReason(err)
|
||||
content.Error = err.Error()
|
||||
}
|
||||
_, err = intent.SendMessageEvent(portal.MXID, event.BeeperMessageStatus, &content)
|
||||
_, err = intent.SendMessageEvent(ctx, portal.MXID, event.BeeperMessageStatus, &content)
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to send message status event:", err)
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to send message status event")
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) sendDeliveryReceipt(eventID id.EventID) {
|
||||
func (portal *Portal) sendDeliveryReceipt(ctx context.Context, eventID id.EventID) {
|
||||
if portal.bridge.Config.Bridge.DeliveryReceipts {
|
||||
err := portal.bridge.Bot.SendReceipt(portal.MXID, eventID, event.ReceiptTypeRead, nil)
|
||||
err := portal.bridge.Bot.SendReceipt(ctx, portal.MXID, eventID, event.ReceiptTypeRead, nil)
|
||||
if err != nil {
|
||||
portal.log.Debugfln("Failed to send delivery receipt for %s: %v", eventID, err)
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to mark message as read by bot (Matrix-side delivery receipt)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part string, ms *metricSender) {
|
||||
var msgType string
|
||||
switch evt.Type {
|
||||
case event.EventMessage:
|
||||
msgType = "message"
|
||||
case event.EventReaction:
|
||||
msgType = "reaction"
|
||||
case event.EventRedaction:
|
||||
msgType = "redaction"
|
||||
case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse:
|
||||
msgType = "poll response"
|
||||
case TypeMSC3381PollStart:
|
||||
msgType = "poll start"
|
||||
default:
|
||||
msgType = "unknown event"
|
||||
}
|
||||
evtDescription := evt.ID.String()
|
||||
if evt.Type == event.EventRedaction {
|
||||
evtDescription += fmt.Sprintf(" of %s", evt.Redacts)
|
||||
}
|
||||
func (portal *Portal) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, ms *metricSender) {
|
||||
origEvtID := evt.ID
|
||||
if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil {
|
||||
origEvtID = retryMeta.OriginalEventID
|
||||
}
|
||||
if err != nil {
|
||||
level := log.LevelError
|
||||
level := zerolog.ErrorLevel
|
||||
if part == "Ignoring" {
|
||||
level = log.LevelDebug
|
||||
level = zerolog.DebugLevel
|
||||
}
|
||||
portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err)
|
||||
zerolog.Ctx(ctx).WithLevel(level).Err(err).Msg(part + " Matrix event")
|
||||
reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err)
|
||||
checkpointStatus := status.ReasonToCheckpointStatus(reason, statusCode)
|
||||
portal.bridge.SendMessageCheckpoint(evt, status.MsgStepRemote, err, checkpointStatus, ms.getRetryNum())
|
||||
if sendNotice {
|
||||
ms.setNoticeID(portal.sendErrorMessage(evt, err, msgType, isCertain, ms.getNoticeID()))
|
||||
ms.setNoticeID(portal.sendErrorMessage(ctx, evt, err, isCertain, ms.getNoticeID()))
|
||||
}
|
||||
portal.sendStatusEvent(origEvtID, evt.ID, err, nil)
|
||||
portal.sendStatusEvent(ctx, origEvtID, evt.ID, err, nil)
|
||||
} else {
|
||||
portal.log.Debugfln("Handled Matrix %s %s", msgType, evtDescription)
|
||||
portal.sendDeliveryReceipt(evt.ID)
|
||||
zerolog.Ctx(ctx).Debug().Msg("Successfully handled Matrix event")
|
||||
portal.sendDeliveryReceipt(ctx, evt.ID)
|
||||
portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, ms.getRetryNum())
|
||||
var deliveredTo *[]id.UserID
|
||||
if portal.IsPrivateChat() {
|
||||
deliveredTo = &[]id.UserID{}
|
||||
}
|
||||
portal.sendStatusEvent(origEvtID, evt.ID, nil, deliveredTo)
|
||||
portal.sendStatusEvent(ctx, origEvtID, evt.ID, nil, deliveredTo)
|
||||
if prevNotice := ms.popNoticeID(); prevNotice != "" {
|
||||
_, _ = portal.MainIntent().RedactEvent(portal.MXID, prevNotice, mautrix.ReqRedact{
|
||||
_, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, prevNotice, mautrix.ReqRedact{
|
||||
Reason: "error resolved",
|
||||
})
|
||||
}
|
||||
}
|
||||
if ms != nil {
|
||||
portal.log.Debugfln("Timings for %s: %s", evt.ID, ms.timings.String())
|
||||
zerolog.Ctx(ctx).Debug().Object("timings", ms.timings).Msg("Matrix event timings")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -264,47 +260,16 @@ type messageTimings struct {
|
|||
totalSend time.Duration
|
||||
}
|
||||
|
||||
func niceRound(dur time.Duration) time.Duration {
|
||||
switch {
|
||||
case dur < time.Millisecond:
|
||||
return dur
|
||||
case dur < time.Second:
|
||||
return dur.Round(100 * time.Microsecond)
|
||||
default:
|
||||
return dur.Round(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func (mt *messageTimings) String() string {
|
||||
mt.initReceive = niceRound(mt.initReceive)
|
||||
mt.decrypt = niceRound(mt.decrypt)
|
||||
mt.portalQueue = niceRound(mt.portalQueue)
|
||||
mt.totalReceive = niceRound(mt.totalReceive)
|
||||
mt.implicitRR = niceRound(mt.implicitRR)
|
||||
mt.preproc = niceRound(mt.preproc)
|
||||
mt.convert = niceRound(mt.convert)
|
||||
mt.whatsmeow.Queue = niceRound(mt.whatsmeow.Queue)
|
||||
mt.whatsmeow.Marshal = niceRound(mt.whatsmeow.Marshal)
|
||||
mt.whatsmeow.GetParticipants = niceRound(mt.whatsmeow.GetParticipants)
|
||||
mt.whatsmeow.GetDevices = niceRound(mt.whatsmeow.GetDevices)
|
||||
mt.whatsmeow.GroupEncrypt = niceRound(mt.whatsmeow.GroupEncrypt)
|
||||
mt.whatsmeow.PeerEncrypt = niceRound(mt.whatsmeow.PeerEncrypt)
|
||||
mt.whatsmeow.Send = niceRound(mt.whatsmeow.Send)
|
||||
mt.whatsmeow.Resp = niceRound(mt.whatsmeow.Resp)
|
||||
mt.whatsmeow.Retry = niceRound(mt.whatsmeow.Retry)
|
||||
mt.totalSend = niceRound(mt.totalSend)
|
||||
whatsmeowTimings := "N/A"
|
||||
if mt.totalSend > 0 {
|
||||
format := "queue: %[1]s, marshal: %[2]s, ske: %[3]s, pcp: %[4]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
|
||||
if mt.whatsmeow.GetParticipants == 0 && mt.whatsmeow.GroupEncrypt == 0 {
|
||||
format = "queue: %[1]s, marshal: %[2]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s"
|
||||
}
|
||||
if mt.whatsmeow.Retry > 0 {
|
||||
format += ", retry: %[9]s"
|
||||
}
|
||||
whatsmeowTimings = fmt.Sprintf(format, mt.whatsmeow.Queue, mt.whatsmeow.Marshal, mt.whatsmeow.GroupEncrypt, mt.whatsmeow.GetParticipants, mt.whatsmeow.GetDevices, mt.whatsmeow.PeerEncrypt, mt.whatsmeow.Send, mt.whatsmeow.Resp, mt.whatsmeow.Retry)
|
||||
}
|
||||
return fmt.Sprintf("BRIDGE: receive: %s, decrypt: %s, queue: %s, total hs->portal: %s, implicit rr: %s -- PORTAL: preprocess: %s, convert: %s, total send: %s -- WHATSMEOW: %s", mt.initReceive, mt.decrypt, mt.implicitRR, mt.portalQueue, mt.totalReceive, mt.preproc, mt.convert, mt.totalSend, whatsmeowTimings)
|
||||
func (mt *messageTimings) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Dur("init_receive", mt.initReceive).
|
||||
Dur("decrypt", mt.decrypt).
|
||||
Dur("implicit_rr", mt.implicitRR).
|
||||
Dur("portal_queue", mt.portalQueue).
|
||||
Dur("total_receive", mt.totalReceive).
|
||||
Dur("preproc", mt.preproc).
|
||||
Dur("convert", mt.convert).
|
||||
Object("whatsmeow", mt.whatsmeow).
|
||||
Dur("total_send", mt.totalSend)
|
||||
}
|
||||
|
||||
type metricSender struct {
|
||||
|
@ -345,13 +310,13 @@ func (ms *metricSender) setNoticeID(evtID id.EventID) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ms *metricSender) sendMessageMetrics(evt *event.Event, err error, part string, completed bool) {
|
||||
func (ms *metricSender) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, completed bool) {
|
||||
ms.lock.Lock()
|
||||
defer ms.lock.Unlock()
|
||||
if !completed && ms.completed {
|
||||
return
|
||||
}
|
||||
ms.portal.sendMessageMetrics(evt, err, part, ms)
|
||||
ms.portal.sendMessageMetrics(ctx, evt, err, part, ms)
|
||||
ms.retryNum++
|
||||
ms.completed = completed
|
||||
}
|
||||
|
|
34
metrics.go
34
metrics.go
|
@ -18,6 +18,7 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
|
@ -27,7 +28,7 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
|
@ -40,7 +41,7 @@ import (
|
|||
type MetricsHandler struct {
|
||||
db *database.Database
|
||||
server *http.Server
|
||||
log log.Logger
|
||||
log zerolog.Logger
|
||||
|
||||
running bool
|
||||
ctx context.Context
|
||||
|
@ -70,7 +71,7 @@ type MetricsHandler struct {
|
|||
loggedInStateLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewMetricsHandler(address string, log log.Logger, db *database.Database) *MetricsHandler {
|
||||
func NewMetricsHandler(address string, log zerolog.Logger, db *database.Database) *MetricsHandler {
|
||||
portalCount := promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "whatsapp_portals_total",
|
||||
Help: "Number of portal rooms on Matrix",
|
||||
|
@ -232,31 +233,31 @@ func (mh *MetricsHandler) TrackConnectionState(jid types.JID, connected bool) {
|
|||
func (mh *MetricsHandler) updateStats() {
|
||||
start := time.Now()
|
||||
var puppetCount int
|
||||
err := mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
|
||||
err := mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
|
||||
if err != nil {
|
||||
mh.log.Warnln("Failed to scan number of puppets:", err)
|
||||
mh.log.Err(err).Msg("Failed to scan number of puppets")
|
||||
} else {
|
||||
mh.puppetCount.Set(float64(puppetCount))
|
||||
}
|
||||
|
||||
var userCount int
|
||||
err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
|
||||
err = mh.db.QueryRow(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
|
||||
if err != nil {
|
||||
mh.log.Warnln("Failed to scan number of users:", err)
|
||||
mh.log.Err(err).Msg("Failed to scan number of users")
|
||||
} else {
|
||||
mh.userCount.Set(float64(userCount))
|
||||
}
|
||||
|
||||
var messageCount int
|
||||
err = mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
|
||||
err = mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
|
||||
if err != nil {
|
||||
mh.log.Warnln("Failed to scan number of messages:", err)
|
||||
mh.log.Err(err).Msg("Failed to scan number of messages")
|
||||
} else {
|
||||
mh.messageCount.Set(float64(messageCount))
|
||||
}
|
||||
|
||||
var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int
|
||||
err = mh.db.QueryRowContext(mh.ctx, `
|
||||
err = mh.db.QueryRow(mh.ctx, `
|
||||
SELECT
|
||||
COUNT(CASE WHEN jid LIKE '%@g.us' AND encrypted THEN 1 END) AS encrypted_group_portals,
|
||||
COUNT(CASE WHEN jid LIKE '%@s.whatsapp.net' AND encrypted THEN 1 END) AS encrypted_private_portals,
|
||||
|
@ -265,7 +266,7 @@ func (mh *MetricsHandler) updateStats() {
|
|||
FROM portal WHERE mxid<>''
|
||||
`).Scan(&encryptedGroupCount, &encryptedPrivateCount, &unencryptedGroupCount, &unencryptedPrivateCount)
|
||||
if err != nil {
|
||||
mh.log.Warnln("Failed to scan number of portals:", err)
|
||||
mh.log.Err(err).Msg("Failed to scan number of portals")
|
||||
} else {
|
||||
mh.encryptedGroupCount.Set(float64(encryptedGroupCount))
|
||||
mh.encryptedPrivateCount.Set(float64(encryptedPrivateCount))
|
||||
|
@ -279,7 +280,10 @@ func (mh *MetricsHandler) startUpdatingStats() {
|
|||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
mh.log.Fatalfln("Panic in metric updater: %v\n%s", err, string(debug.Stack()))
|
||||
mh.log.WithLevel(zerolog.PanicLevel).
|
||||
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
|
||||
Interface(zerolog.ErrorFieldName, err).
|
||||
Msg("Panic in metric updater")
|
||||
}
|
||||
}()
|
||||
ticker := time.Tick(10 * time.Second)
|
||||
|
@ -299,8 +303,8 @@ func (mh *MetricsHandler) Start() {
|
|||
go mh.startUpdatingStats()
|
||||
err := mh.server.ListenAndServe()
|
||||
mh.running = false
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
mh.log.Fatalln("Error in metrics listener:", err)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
mh.log.Err(err).Msg("Error in metrics listener")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -311,6 +315,6 @@ func (mh *MetricsHandler) Stop() {
|
|||
mh.stopRecorder()
|
||||
err := mh.server.Close()
|
||||
if err != nil {
|
||||
mh.log.Errorln("Error closing metrics listener:", err)
|
||||
mh.log.Err(err).Msg("Failed to close metrics listener")
|
||||
}
|
||||
}
|
||||
|
|
128
provisioning.go
128
provisioning.go
|
@ -17,41 +17,40 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/beeper/libserv/pkg/requestlog"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/hlog"
|
||||
|
||||
"go.mau.fi/whatsmeow/appstate"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
"go.mau.fi/whatsmeow"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/bridge/status"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type ProvisioningAPI struct {
|
||||
bridge *WABridge
|
||||
log log.Logger
|
||||
log zerolog.Logger
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) Init() {
|
||||
prov.log = prov.bridge.Log.Sub("Provisioning")
|
||||
|
||||
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
|
||||
prov.log.Debug().Str("base_path", prov.bridge.Config.Bridge.Provisioning.Prefix).Msg("Enabling provisioning API")
|
||||
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
|
||||
r.Use(hlog.NewHandler(prov.log))
|
||||
r.Use(requestlog.AccessLogger(true))
|
||||
r.Use(prov.AuthMiddleware)
|
||||
r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
|
||||
r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
|
||||
|
@ -73,7 +72,7 @@ func (prov *ProvisioningAPI) Init() {
|
|||
prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.bridge_state", prov.BridgeStatePing).Methods(http.MethodPost)
|
||||
|
||||
if prov.bridge.Config.Bridge.Provisioning.DebugEndpoints {
|
||||
prov.log.Debugln("Enabling debug API at /debug")
|
||||
prov.log.Debug().Msg("Enabling debug API at /debug")
|
||||
r := prov.bridge.AS.Router.PathPrefix("/debug").Subrouter()
|
||||
r.Use(prov.AuthMiddleware)
|
||||
r.PathPrefix("/pprof").Handler(http.DefaultServeMux)
|
||||
|
@ -83,26 +82,6 @@ func (prov *ProvisioningAPI) Init() {
|
|||
r.HandleFunc("/v1/delete_connection", prov.Disconnect).Methods(http.MethodPost)
|
||||
}
|
||||
|
||||
type responseWrap struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
var _ http.Hijacker = (*responseWrap)(nil)
|
||||
|
||||
func (rw *responseWrap) WriteHeader(statusCode int) {
|
||||
rw.ResponseWriter.WriteHeader(statusCode)
|
||||
rw.statusCode = statusCode
|
||||
}
|
||||
|
||||
func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("response does not implement http.Hijacker")
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
|
@ -119,7 +98,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
|||
auth = auth[len("Bearer "):]
|
||||
}
|
||||
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
|
||||
prov.log.Infof("Authentication token does not match shared secret")
|
||||
hlog.FromRequest(r).Debug().Msg("Authentication token does not match shared secret")
|
||||
jsonResponse(w, http.StatusForbidden, map[string]interface{}{
|
||||
"error": "Authentication token does not match shared secret",
|
||||
"errcode": "M_FORBIDDEN",
|
||||
|
@ -128,11 +107,12 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
|||
}
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
user := prov.bridge.GetUserByMXID(id.UserID(userID))
|
||||
start := time.Now()
|
||||
wWrap := &responseWrap{w, 200}
|
||||
h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user)))
|
||||
duration := time.Now().Sub(start).Seconds()
|
||||
prov.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode)
|
||||
if user != nil {
|
||||
hlog.FromRequest(r).UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Stringer("user_id", user.MXID)
|
||||
})
|
||||
}
|
||||
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user)))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -157,7 +137,7 @@ func (prov *ProvisioningAPI) DeleteSession(w http.ResponseWriter, r *http.Reques
|
|||
return
|
||||
}
|
||||
user.DeleteConnection()
|
||||
user.DeleteSession()
|
||||
user.DeleteSession(r.Context())
|
||||
jsonResponse(w, http.StatusOK, Response{true, "Session information purged"})
|
||||
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
|
||||
}
|
||||
|
@ -245,7 +225,7 @@ func (prov *ProvisioningAPI) ListContacts(w http.ResponseWriter, r *http.Request
|
|||
ErrCode: "no session",
|
||||
})
|
||||
} else if contacts, err := user.Session.Contacts.GetAllContacts(); err != nil {
|
||||
prov.log.Errorfln("Failed to fetch %s's contacts: %v", user.MXID, err)
|
||||
hlog.FromRequest(r).Err(err).Msg("Failed to fetch all contacts")
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: "Internal server error while fetching contact list",
|
||||
ErrCode: "failed to get contacts",
|
||||
|
@ -282,7 +262,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
|
|||
if r.Method == http.MethodPost {
|
||||
err := user.ResyncGroups(r.URL.Query().Get("create_portals") == "true")
|
||||
if err != nil {
|
||||
prov.log.Errorfln("Failed to resync %s's groups: %v", user.MXID, err)
|
||||
hlog.FromRequest(r).Err(err).Msg("Failed to resync groups")
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: "Internal server error while resyncing groups",
|
||||
ErrCode: "failed to sync groups",
|
||||
|
@ -291,7 +271,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
}
|
||||
if groups, err := user.getCachedGroupList(); err != nil {
|
||||
prov.log.Errorfln("Failed to fetch %s's groups: %v", user.MXID, err)
|
||||
hlog.FromRequest(r).Err(err).Msg("Failed to fetch group list")
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: "Internal server error while fetching group list",
|
||||
ErrCode: "failed to get groups",
|
||||
|
@ -368,17 +348,17 @@ func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) {
|
|||
// resolveIdentifier already responded with an error
|
||||
return
|
||||
}
|
||||
portal, puppet, justCreated, err := user.StartPM(jid, "provisioning API PM")
|
||||
portal, puppet, justCreated, err := user.StartPM(r.Context(), jid, "provisioning API PM")
|
||||
if err != nil {
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: fmt.Sprintf("Failed to create portal: %v", err),
|
||||
})
|
||||
}
|
||||
status := http.StatusOK
|
||||
statusCode := http.StatusOK
|
||||
if justCreated {
|
||||
status = http.StatusCreated
|
||||
statusCode = http.StatusCreated
|
||||
}
|
||||
jsonResponse(w, status, PortalInfo{
|
||||
jsonResponse(w, statusCode, PortalInfo{
|
||||
RoomID: portal.MXID,
|
||||
OtherUser: &OtherUserInfo{
|
||||
JID: puppet.JID,
|
||||
|
@ -449,29 +429,30 @@ func (prov *ProvisioningAPI) OpenGroup(w http.ResponseWriter, r *http.Request) {
|
|||
ErrCode: "invalid group id",
|
||||
})
|
||||
} else if info, err := user.Client.GetGroupInfo(jid); err != nil {
|
||||
hlog.FromRequest(r).Err(err).Msg("Failed to get group info by JID")
|
||||
// TODO return better responses for different errors (like ErrGroupNotFound and ErrNotInGroup)
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: fmt.Sprintf("Failed to get group info: %v", err),
|
||||
ErrCode: "error getting group info",
|
||||
})
|
||||
} else {
|
||||
prov.log.Debugln("Importing", jid, "for", user.MXID)
|
||||
hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Importing group chat for user")
|
||||
portal := user.GetPortalByJID(info.JID)
|
||||
status := http.StatusOK
|
||||
statusCode := http.StatusOK
|
||||
if len(portal.MXID) == 0 {
|
||||
err = portal.CreateMatrixRoom(user, info, nil, true, true)
|
||||
err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
|
||||
if err != nil {
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: fmt.Sprintf("Failed to create portal: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
status = http.StatusCreated
|
||||
statusCode = http.StatusCreated
|
||||
}
|
||||
jsonResponse(w, status, PortalInfo{
|
||||
jsonResponse(w, statusCode, PortalInfo{
|
||||
RoomID: portal.MXID,
|
||||
GroupInfo: info,
|
||||
JustCreated: status == http.StatusCreated,
|
||||
JustCreated: statusCode == http.StatusCreated,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -495,6 +476,7 @@ func (prov *ProvisioningAPI) resolveGroupInvite(w http.ResponseWriter, r *http.R
|
|||
ErrCode: "invalid invite link",
|
||||
})
|
||||
} else {
|
||||
hlog.FromRequest(r).Err(err).Msg("Failed to get group info from link")
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: fmt.Sprintf("Failed to fetch group info with link: %v", err),
|
||||
ErrCode: "error getting group info",
|
||||
|
@ -530,29 +512,30 @@ func (prov *ProvisioningAPI) JoinGroup(w http.ResponseWriter, r *http.Request) {
|
|||
}()
|
||||
inviteCode, _ := mux.Vars(r)["inviteCode"]
|
||||
if jid, err := user.Client.JoinGroupWithLink(inviteCode); err != nil {
|
||||
hlog.FromRequest(r).Err(err).Msg("Failed to join group")
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: fmt.Sprintf("Failed to join group: %v", err),
|
||||
ErrCode: "error joining group",
|
||||
})
|
||||
} else {
|
||||
prov.log.Debugln(user.MXID, "successfully joined group", jid)
|
||||
hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Successfully joined group")
|
||||
portal := user.GetPortalByJID(jid)
|
||||
status := http.StatusOK
|
||||
statusCode := http.StatusOK
|
||||
if len(portal.MXID) == 0 {
|
||||
time.Sleep(500 * time.Millisecond) // Wait for incoming group info to create the portal automatically
|
||||
err = portal.CreateMatrixRoom(user, info, nil, true, true)
|
||||
err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true)
|
||||
if err != nil {
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: fmt.Sprintf("Failed to create portal: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
status = http.StatusCreated
|
||||
statusCode = http.StatusCreated
|
||||
}
|
||||
jsonResponse(w, status, PortalInfo{
|
||||
jsonResponse(w, statusCode, PortalInfo{
|
||||
RoomID: portal.MXID,
|
||||
GroupInfo: info,
|
||||
JustCreated: status == http.StatusCreated,
|
||||
JustCreated: statusCode == http.StatusCreated,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -616,7 +599,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
|
|||
} else {
|
||||
err := user.Client.Logout()
|
||||
if err != nil {
|
||||
user.log.Warnln("Error while logging out:", err)
|
||||
hlog.FromRequest(r).Err(err).Msg("Unknown error while logging out")
|
||||
if !force {
|
||||
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||
Error: fmt.Sprintf("Unknown error while logging out: %v", err),
|
||||
|
@ -632,7 +615,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
user.bridge.Metrics.TrackConnectionState(user.JID, false)
|
||||
user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut})
|
||||
user.DeleteSession()
|
||||
user.DeleteSession(r.Context())
|
||||
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
|
||||
}
|
||||
|
||||
|
@ -646,16 +629,17 @@ var upgrader = websocket.Upgrader{
|
|||
func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
user := prov.bridge.GetUserByMXID(id.UserID(userID))
|
||||
log := hlog.FromRequest(r)
|
||||
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
prov.log.Errorln("Failed to upgrade connection to websocket:", err)
|
||||
log.Err(err).Msg("Failed to upgrade connection to websocket")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err := c.Close()
|
||||
if err != nil {
|
||||
user.log.Debugln("Error closing websocket:", err)
|
||||
log.Debug().Err(err).Msg("Error closing websocket")
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -670,23 +654,26 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
}()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c.SetCloseHandler(func(code int, text string) error {
|
||||
user.log.Debugfln("Login websocket closed (%d), cancelling login", code)
|
||||
log.Debug().Int("close_code", code).Msg("Login websocket closed, cancelling login")
|
||||
cancel()
|
||||
return nil
|
||||
})
|
||||
|
||||
if userTimezone := r.URL.Query().Get("tz"); userTimezone != "" {
|
||||
user.log.Debug("Setting timezone to %s", userTimezone)
|
||||
log.Debug().Str("timezone", userTimezone).Msg("Updating user timezone")
|
||||
user.Timezone = userTimezone
|
||||
user.Update()
|
||||
err = user.Update(r.Context())
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save user after updating timezone")
|
||||
}
|
||||
} else {
|
||||
user.log.Debug("No timezone provided in request")
|
||||
log.Debug().Msg("No timezone provided in request")
|
||||
}
|
||||
|
||||
qrChan, err := user.Login(ctx)
|
||||
expiryTime := time.Now().Add(160 * time.Second)
|
||||
if err != nil {
|
||||
user.log.Errorln("Failed to log in from provisioning API:", err)
|
||||
log.Err(err).Msg("Failed to log in via provisioning API")
|
||||
if errors.Is(err, ErrAlreadyLoggedIn) {
|
||||
go user.Connect()
|
||||
_ = c.WriteJSON(Error{
|
||||
|
@ -704,7 +691,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
if phoneNum != "" {
|
||||
pairingCode, err := user.Client.PairPhone(phoneNum, true, whatsmeow.PairClientChrome, "Chrome (Linux)")
|
||||
if err != nil {
|
||||
user.zlog.Err(err).Msg("Failed to start phone code login")
|
||||
log.Err(err).Msg("Failed to start phone code login")
|
||||
_ = c.WriteJSON(Error{
|
||||
Error: "Failed to request pairing code",
|
||||
ErrCode: "code error",
|
||||
|
@ -712,6 +699,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
go user.DeleteConnection()
|
||||
return
|
||||
} else {
|
||||
log.Debug().Msg("Started phone number login")
|
||||
_ = c.WriteJSON(map[string]any{
|
||||
"pairing_code": pairingCode,
|
||||
"timeout": int(time.Until(expiryTime).Seconds()),
|
||||
|
@ -719,7 +707,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
user.log.Debugln("Started login via provisioning API")
|
||||
log.Debug().Msg("Started login via provisioning API")
|
||||
Analytics.Track(user.MXID, "$login_start")
|
||||
|
||||
for {
|
||||
|
@ -728,7 +716,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
switch evt.Event {
|
||||
case whatsmeow.QRChannelSuccess.Event:
|
||||
jid := user.Client.Store.ID
|
||||
user.log.Debugln("Successful login as", jid, "via provisioning API")
|
||||
log.Debug().Stringer("jid", jid).Msg("Successful login via provisioning API")
|
||||
Analytics.Track(user.MXID, "$login_success")
|
||||
_ = c.WriteJSON(map[string]interface{}{
|
||||
"success": true,
|
||||
|
@ -737,7 +725,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
"platform": user.Client.Store.Platform,
|
||||
})
|
||||
case whatsmeow.QRChannelTimeout.Event:
|
||||
user.log.Debugln("Login via provisioning API timed out")
|
||||
log.Debug().Msg("Login via provisioning API timed out")
|
||||
errCode := "login timed out"
|
||||
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
|
||||
_ = c.WriteJSON(Error{
|
||||
|
@ -745,7 +733,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
ErrCode: errCode,
|
||||
})
|
||||
case whatsmeow.QRChannelErrUnexpectedEvent.Event:
|
||||
user.log.Debugln("Login via provisioning API failed due to unexpected event")
|
||||
log.Debug().Msg("Login via provisioning API failed due to unexpected event")
|
||||
errCode := "unexpected event"
|
||||
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
|
||||
_ = c.WriteJSON(Error{
|
||||
|
@ -753,7 +741,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||
ErrCode: errCode,
|
||||
})
|
||||
case whatsmeow.QRChannelClientOutdated.Event:
|
||||
user.log.Debugln("Login via provisioning API failed due to outdated client")
|
||||
log.Debug().Msg("Login via provisioning API failed due to outdated client")
|
||||
errCode := "bridge outdated"
|
||||
Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
|
||||
_ = c.WriteJSON(Error{
|
||||
|
|
137
puppet.go
137
puppet.go
|
@ -17,15 +17,15 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
|
@ -59,6 +59,7 @@ func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
|
|||
}
|
||||
|
||||
func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
|
||||
ctx := context.TODO()
|
||||
jid = jid.ToNonAD()
|
||||
if jid.Server == types.LegacyUserServer {
|
||||
jid.Server = types.DefaultUserServer
|
||||
|
@ -69,11 +70,19 @@ func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
|
|||
defer br.puppetsLock.Unlock()
|
||||
puppet, ok := br.puppets[jid]
|
||||
if !ok {
|
||||
dbPuppet := br.DB.Puppet.Get(jid)
|
||||
dbPuppet, err := br.DB.Puppet.Get(ctx, jid)
|
||||
if err != nil {
|
||||
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to get puppet from database")
|
||||
return nil
|
||||
}
|
||||
if dbPuppet == nil {
|
||||
dbPuppet = br.DB.Puppet.New()
|
||||
dbPuppet.JID = jid
|
||||
dbPuppet.Insert()
|
||||
err = dbPuppet.Insert(ctx)
|
||||
if err != nil {
|
||||
br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to insert new puppet to database")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
puppet = br.NewPuppet(dbPuppet)
|
||||
br.puppets[puppet.JID] = puppet
|
||||
|
@ -89,7 +98,10 @@ func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
|
|||
defer br.puppetsLock.Unlock()
|
||||
puppet, ok := br.puppetsByCustomMXID[mxid]
|
||||
if !ok {
|
||||
dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid)
|
||||
dbPuppet, err := br.DB.Puppet.GetByCustomMXID(context.TODO(), mxid)
|
||||
if err != nil {
|
||||
br.ZLog.Err(err).Stringer("mxid", mxid).Msg("Failed to get puppet by custom mxid from database")
|
||||
}
|
||||
if dbPuppet == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -137,14 +149,18 @@ func (puppet *Puppet) GetMXID() id.UserID {
|
|||
}
|
||||
|
||||
func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
|
||||
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID())
|
||||
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID(context.TODO()))
|
||||
}
|
||||
|
||||
func (br *WABridge) GetAllPuppets() []*Puppet {
|
||||
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll())
|
||||
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll(context.TODO()))
|
||||
}
|
||||
|
||||
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
|
||||
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet, err error) []*Puppet {
|
||||
if err != nil {
|
||||
br.ZLog.Err(err).Msg("Error getting puppets from database")
|
||||
return nil
|
||||
}
|
||||
br.puppetsLock.Lock()
|
||||
defer br.puppetsLock.Unlock()
|
||||
output := make([]*Puppet, len(dbPuppets))
|
||||
|
@ -175,7 +191,7 @@ func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
|
|||
return &Puppet{
|
||||
Puppet: dbPuppet,
|
||||
bridge: br,
|
||||
log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
|
||||
zlog: br.ZLog.With().Stringer("puppet_jid", dbPuppet.JID).Logger(),
|
||||
|
||||
MXID: br.FormatPuppetMXID(dbPuppet.JID),
|
||||
}
|
||||
|
@ -185,7 +201,7 @@ type Puppet struct {
|
|||
*database.Puppet
|
||||
|
||||
bridge *WABridge
|
||||
log log.Logger
|
||||
zlog zerolog.Logger
|
||||
|
||||
typingIn id.RoomID
|
||||
typingAt time.Time
|
||||
|
@ -223,47 +239,47 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI {
|
|||
return puppet.bridge.AS.Intent(puppet.MXID)
|
||||
}
|
||||
|
||||
func (puppet *Puppet) UpdateAvatar(source *User, forcePortalSync bool) bool {
|
||||
changed := source.updateAvatar(puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.log, puppet.DefaultIntent())
|
||||
func (puppet *Puppet) UpdateAvatar(ctx context.Context, source *User, forcePortalSync bool) bool {
|
||||
changed := source.updateAvatar(ctx, puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.DefaultIntent())
|
||||
if !changed || puppet.Avatar == "unauthorized" {
|
||||
if forcePortalSync {
|
||||
go puppet.updatePortalAvatar()
|
||||
go puppet.updatePortalAvatar(ctx)
|
||||
}
|
||||
return changed
|
||||
}
|
||||
err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
|
||||
err := puppet.DefaultIntent().SetAvatarURL(ctx, puppet.AvatarURL)
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to set avatar:", err)
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to set avatar from puppet")
|
||||
} else {
|
||||
puppet.AvatarSet = true
|
||||
}
|
||||
go puppet.updatePortalAvatar()
|
||||
go puppet.updatePortalAvatar(ctx)
|
||||
return true
|
||||
}
|
||||
|
||||
func (puppet *Puppet) UpdateName(contact types.ContactInfo, forcePortalSync bool) bool {
|
||||
func (puppet *Puppet) UpdateName(ctx context.Context, contact types.ContactInfo, forcePortalSync bool) bool {
|
||||
newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(puppet.JID, contact)
|
||||
if (puppet.Displayname != newName || !puppet.NameSet) && quality >= puppet.NameQuality {
|
||||
oldName := puppet.Displayname
|
||||
puppet.Displayname = newName
|
||||
puppet.NameQuality = quality
|
||||
puppet.NameSet = false
|
||||
err := puppet.DefaultIntent().SetDisplayName(newName)
|
||||
err := puppet.DefaultIntent().SetDisplayName(ctx, newName)
|
||||
if err == nil {
|
||||
puppet.log.Debugln("Updated name", oldName, "->", newName)
|
||||
puppet.zlog.Debug().Str("old_name", oldName).Str("new_name", newName).Msg("Updated name")
|
||||
puppet.NameSet = true
|
||||
go puppet.updatePortalName()
|
||||
go puppet.updatePortalName(ctx)
|
||||
} else {
|
||||
puppet.log.Warnln("Failed to set display name:", err)
|
||||
puppet.zlog.Err(err).Msg("Failed to set displayname")
|
||||
}
|
||||
return true
|
||||
} else if forcePortalSync {
|
||||
go puppet.updatePortalName()
|
||||
go puppet.updatePortalName(ctx)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (puppet *Puppet) UpdateContactInfo() bool {
|
||||
func (puppet *Puppet) UpdateContactInfo(ctx context.Context) bool {
|
||||
if !puppet.bridge.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
|
||||
return false
|
||||
}
|
||||
|
@ -281,9 +297,9 @@ func (puppet *Puppet) UpdateContactInfo() bool {
|
|||
"com.beeper.bridge.service": "whatsapp",
|
||||
"com.beeper.bridge.network": "whatsapp",
|
||||
}
|
||||
err := puppet.DefaultIntent().BeeperUpdateProfile(contactInfo)
|
||||
err := puppet.DefaultIntent().BeeperUpdateProfile(ctx, contactInfo)
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to store custom contact info in profile:", err)
|
||||
puppet.zlog.Err(err).Msg("Failed to store custom contact info in profile")
|
||||
return false
|
||||
} else {
|
||||
puppet.ContactInfoSet = true
|
||||
|
@ -300,7 +316,7 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) {
|
|||
}
|
||||
}
|
||||
|
||||
func (puppet *Puppet) updatePortalAvatar() {
|
||||
func (puppet *Puppet) updatePortalAvatar(ctx context.Context) {
|
||||
puppet.updatePortalMeta(func(portal *Portal) {
|
||||
if portal.Avatar == puppet.Avatar && portal.AvatarURL == puppet.AvatarURL && (portal.AvatarSet || !portal.shouldSetDMRoomMetadata()) {
|
||||
return
|
||||
|
@ -308,28 +324,31 @@ func (puppet *Puppet) updatePortalAvatar() {
|
|||
portal.AvatarURL = puppet.AvatarURL
|
||||
portal.Avatar = puppet.Avatar
|
||||
portal.AvatarSet = false
|
||||
defer portal.Update(nil)
|
||||
if len(portal.MXID) > 0 && !portal.shouldSetDMRoomMetadata() {
|
||||
portal.UpdateBridgeInfo()
|
||||
portal.UpdateBridgeInfo(ctx)
|
||||
} else if len(portal.MXID) > 0 {
|
||||
_, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL)
|
||||
_, err := portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, puppet.AvatarURL)
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to set avatar:", err)
|
||||
portal.zlog.Err(err).Msg("Failed to set avatar from puppet")
|
||||
} else {
|
||||
portal.AvatarSet = true
|
||||
portal.UpdateBridgeInfo()
|
||||
portal.UpdateBridgeInfo(ctx)
|
||||
}
|
||||
}
|
||||
err := portal.Update(ctx)
|
||||
if err != nil {
|
||||
portal.zlog.Err(err).Msg("Failed to save portal after updating avatar from puppet")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (puppet *Puppet) updatePortalName() {
|
||||
func (puppet *Puppet) updatePortalName(ctx context.Context) {
|
||||
puppet.updatePortalMeta(func(portal *Portal) {
|
||||
portal.UpdateName(puppet.Displayname, types.EmptyJID, true)
|
||||
portal.UpdateName(ctx, puppet.Displayname, types.EmptyJID, true)
|
||||
})
|
||||
}
|
||||
|
||||
func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName bool, reason string) {
|
||||
func (puppet *Puppet) SyncContact(ctx context.Context, source *User, onlyIfNoName, shouldHavePushName bool, reason string) {
|
||||
if puppet == nil {
|
||||
return
|
||||
}
|
||||
|
@ -337,39 +356,67 @@ func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName
|
|||
source.EnqueuePuppetResync(puppet)
|
||||
return
|
||||
}
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("method", "Puppet.SyncContact").
|
||||
Stringer("puppet_jid", puppet.JID).
|
||||
Stringer("source_user_jid", source.JID).
|
||||
Stringer("source_user_mxid", source.MXID).
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
|
||||
contact, err := source.Client.Store.Contacts.GetContact(puppet.JID)
|
||||
if err != nil {
|
||||
puppet.log.Warnfln("Failed to get contact info through %s in SyncContact: %v (sync reason: %s)", source.MXID, reason)
|
||||
log.Err(err).
|
||||
Stringer("source_mxid", source.MXID).
|
||||
Str("sync_reason", reason).
|
||||
Msg("Failed to get contact info through user in SyncContact")
|
||||
} else if !contact.Found {
|
||||
puppet.log.Warnfln("No contact info found through %s in SyncContact (sync reason: %s)", source.MXID, reason)
|
||||
log.Warn().
|
||||
Stringer("source_mxid", source.MXID).
|
||||
Str("sync_reason", reason).
|
||||
Msg("No contact info found through user in SyncContact")
|
||||
}
|
||||
puppet.Sync(source, &contact, false, false)
|
||||
puppet.syncInternal(ctx, source, &contact, false, false)
|
||||
}
|
||||
|
||||
func (puppet *Puppet) Sync(source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
|
||||
func (puppet *Puppet) Sync(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("method", "Puppet.Sync").
|
||||
Stringer("puppet_jid", puppet.JID).
|
||||
Stringer("source_user_jid", source.JID).
|
||||
Stringer("source_user_mxid", source.MXID).
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
puppet.syncInternal(ctx, source, contact, forceAvatarSync, forcePortalSync)
|
||||
}
|
||||
|
||||
func (puppet *Puppet) syncInternal(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
puppet.syncLock.Lock()
|
||||
defer puppet.syncLock.Unlock()
|
||||
err := puppet.DefaultIntent().EnsureRegistered()
|
||||
err := puppet.DefaultIntent().EnsureRegistered(ctx)
|
||||
if err != nil {
|
||||
puppet.log.Errorln("Failed to ensure registered:", err)
|
||||
log.Err(err).Msg("Failed to ensure registered")
|
||||
}
|
||||
|
||||
puppet.log.Debugfln("Syncing info through %s", source.JID)
|
||||
log.Debug().Stringer("source_jid", source.JID).Msg("Syncing info through user")
|
||||
|
||||
update := false
|
||||
if contact != nil {
|
||||
if puppet.JID.User == source.JID.User {
|
||||
contact.PushName = source.Client.Store.PushName
|
||||
}
|
||||
update = puppet.UpdateName(*contact, forcePortalSync) || update
|
||||
update = puppet.UpdateName(ctx, *contact, forcePortalSync) || update
|
||||
}
|
||||
if len(puppet.Avatar) == 0 || forceAvatarSync || puppet.bridge.Config.Bridge.UserAvatarSync {
|
||||
update = puppet.UpdateAvatar(source, forcePortalSync) || update
|
||||
update = puppet.UpdateAvatar(ctx, source, forcePortalSync) || update
|
||||
}
|
||||
update = puppet.UpdateContactInfo() || update
|
||||
update = puppet.UpdateContactInfo(ctx) || update
|
||||
if update || puppet.LastSync.Add(24*time.Hour).Before(time.Now()) {
|
||||
puppet.LastSync = time.Now()
|
||||
puppet.Update()
|
||||
err = puppet.Update(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save puppet after sync")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
// Copyright (C) 2024 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -19,7 +19,6 @@ package main
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"image"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -27,33 +26,26 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/idna"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"go.mau.fi/whatsmeow"
|
||||
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/crypto/attachment"
|
||||
"maunium.net/go/mautrix/event"
|
||||
)
|
||||
|
||||
type BeeperLinkPreview struct {
|
||||
mautrix.RespPreviewURL
|
||||
MatchedURL string `json:"matched_url"`
|
||||
ImageEncryption *event.EncryptedFileInfo `json:"beeper:image:encryption,omitempty"`
|
||||
}
|
||||
|
||||
func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*BeeperLinkPreview {
|
||||
func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*event.BeeperLinkPreview {
|
||||
if msg.GetMatchedText() == "" {
|
||||
return []*BeeperLinkPreview{}
|
||||
return []*event.BeeperLinkPreview{}
|
||||
}
|
||||
|
||||
output := &BeeperLinkPreview{
|
||||
output := &event.BeeperLinkPreview{
|
||||
MatchedURL: msg.GetMatchedText(),
|
||||
RespPreviewURL: mautrix.RespPreviewURL{
|
||||
LinkPreview: event.LinkPreview{
|
||||
CanonicalURL: msg.GetCanonicalUrl(),
|
||||
Title: msg.GetTitle(),
|
||||
Description: msg.GetDescription(),
|
||||
|
@ -68,7 +60,7 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
|
|||
var err error
|
||||
thumbnailData, err = source.Client.DownloadThumbnail(msg)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to download thumbnail for link preview: %v", err)
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to download thumbnail for link preview")
|
||||
}
|
||||
}
|
||||
if thumbnailData == nil && msg.JpegThumbnail != nil {
|
||||
|
@ -93,9 +85,9 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
|
|||
uploadMime = "application/octet-stream"
|
||||
output.ImageEncryption = &event.EncryptedFileInfo{EncryptedFile: *crypto}
|
||||
}
|
||||
resp, err := intent.UploadBytes(uploadData, uploadMime)
|
||||
resp, err := intent.UploadBytes(ctx, uploadData, uploadMime)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to reupload thumbnail for link preview: %v", err)
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload thumbnail for link preview")
|
||||
} else {
|
||||
if output.ImageEncryption != nil {
|
||||
output.ImageEncryption.URL = resp.ContentURI.CUString()
|
||||
|
@ -108,35 +100,36 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so
|
|||
output.Type = "video.other"
|
||||
}
|
||||
|
||||
return []*BeeperLinkPreview{output}
|
||||
return []*event.BeeperLinkPreview{output}
|
||||
}
|
||||
|
||||
var URLRegex = regexp.MustCompile(`https?://[^\s/_*]+(?:/\S*)?`)
|
||||
|
||||
func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, evt *event.Event, dest *waProto.ExtendedTextMessage) bool {
|
||||
var preview *BeeperLinkPreview
|
||||
func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, content *event.MessageEventContent, dest *waProto.ExtendedTextMessage) bool {
|
||||
log := zerolog.Ctx(ctx)
|
||||
var preview *event.BeeperLinkPreview
|
||||
|
||||
rawPreview := gjson.GetBytes(evt.Content.VeryRaw, `com\.beeper\.linkpreviews`)
|
||||
if rawPreview.Exists() && rawPreview.IsArray() {
|
||||
var previews []BeeperLinkPreview
|
||||
if err := json.Unmarshal([]byte(rawPreview.Raw), &previews); err != nil || len(previews) == 0 {
|
||||
if content.BeeperLinkPreviews != nil {
|
||||
// Note: this check explicitly happens after checking for nil: empty arrays are treated as no previews,
|
||||
// but omitting the field means the bridge may look for URLs in the message text.
|
||||
if len(content.BeeperLinkPreviews) == 0 {
|
||||
return false
|
||||
}
|
||||
// WhatsApp only supports a single preview.
|
||||
preview = &previews[0]
|
||||
preview = content.BeeperLinkPreviews[0]
|
||||
} else if portal.bridge.Config.Bridge.URLPreviews {
|
||||
if matchedURL := URLRegex.FindString(evt.Content.AsMessage().Body); len(matchedURL) == 0 {
|
||||
if matchedURL := URLRegex.FindString(content.Body); len(matchedURL) == 0 {
|
||||
return false
|
||||
} else if parsed, err := url.Parse(matchedURL); err != nil {
|
||||
return false
|
||||
} else if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil {
|
||||
return false
|
||||
} else if mxPreview, err := portal.MainIntent().GetURLPreview(parsed.String()); err != nil {
|
||||
portal.log.Warnfln("Failed to fetch preview for %s: %v", matchedURL, err)
|
||||
} else if mxPreview, err := portal.MainIntent().GetURLPreview(ctx, parsed.String()); err != nil {
|
||||
log.Err(err).Str("url", matchedURL).Msg("Failed to fetch URL preview")
|
||||
return false
|
||||
} else {
|
||||
preview = &BeeperLinkPreview{
|
||||
RespPreviewURL: *mxPreview,
|
||||
preview = &event.BeeperLinkPreview{
|
||||
LinkPreview: *mxPreview,
|
||||
MatchedURL: matchedURL,
|
||||
}
|
||||
}
|
||||
|
@ -163,22 +156,22 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
|
|||
imageMXC = preview.ImageEncryption.URL.ParseOrIgnore()
|
||||
}
|
||||
if !imageMXC.IsEmpty() {
|
||||
data, err := portal.MainIntent().DownloadBytesContext(ctx, imageMXC)
|
||||
data, err := portal.MainIntent().DownloadBytes(ctx, imageMXC)
|
||||
if err != nil {
|
||||
portal.log.Errorfln("Failed to download URL preview image %s in %s: %v", preview.ImageURL, evt.ID, err)
|
||||
log.Err(err).Str("image_url", string(preview.ImageURL)).Msg("Failed to download URL preview image")
|
||||
return true
|
||||
}
|
||||
if preview.ImageEncryption != nil {
|
||||
err = preview.ImageEncryption.DecryptInPlace(data)
|
||||
if err != nil {
|
||||
portal.log.Errorfln("Failed to decrypt URL preview image in %s: %v", evt.ID, err)
|
||||
log.Err(err).Msg("Failed to decrypt URL preview image")
|
||||
return true
|
||||
}
|
||||
}
|
||||
dest.MediaKeyTimestamp = proto.Int64(time.Now().Unix())
|
||||
uploadResp, err := sender.Client.Upload(ctx, data, whatsmeow.MediaLinkThumbnail)
|
||||
if err != nil {
|
||||
portal.log.Errorfln("Failed to upload URL preview thumbnail in %s: %v", evt.ID, err)
|
||||
log.Err(err).Msg("Failed to reupload URL preview thumbnail")
|
||||
return true
|
||||
}
|
||||
dest.ThumbnailSha256 = uploadResp.FileSHA256
|
||||
|
@ -188,7 +181,7 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U
|
|||
var width, height int
|
||||
dest.JpegThumbnail, width, height, err = createThumbnailAndGetSize(data, false)
|
||||
if err != nil {
|
||||
portal.log.Warnfln("Failed to create JPEG thumbnail for URL preview in %s: %v", evt.ID, err)
|
||||
log.Err(err).Msg("Failed to create JPEG thumbnail for URL preview")
|
||||
}
|
||||
if preview.ImageHeight > 0 && preview.ImageWidth > 0 {
|
||||
dest.ThumbnailWidth = proto.Uint32(uint32(preview.ImageWidth))
|
||||
|
|
Loading…
Reference in a new issue